// apple-vision 0.16.5
//
// Safe Rust bindings for Apple's Vision framework — OCR, object detection, face landmarks on macOS.
// Documentation
// Stateful tracking helpers plus object/rectangle trackers.

import AppKit
import CoreGraphics
import CoreImage
import CoreML
import CoreVideo
import Foundation
import ImageIO
import Vision

/// Flattens a `VNRectangleObservation` into the plain `VNRectangleObservationRaw` record.
///
/// The bounding box origin/size and the four corner points are widened to `Double`;
/// the observation's confidence is passed through unchanged.
///
/// - Parameter obs: The Vision rectangle observation to convert.
/// - Returns: A raw record carrying the bounding box, confidence and corner coordinates of `obs`.
internal func mkRectangleRaw(_ obs: VNRectangleObservation) -> VNRectangleObservationRaw {
    let box = obs.boundingBox
    let topLeft = obs.topLeft
    let topRight = obs.topRight
    let bottomLeft = obs.bottomLeft
    let bottomRight = obs.bottomRight

    return VNRectangleObservationRaw(
        bbox_x: Double(box.origin.x),
        bbox_y: Double(box.origin.y),
        bbox_w: Double(box.size.width),
        bbox_h: Double(box.size.height),
        confidence: obs.confidence,
        tl_x: Double(topLeft.x),
        tl_y: Double(topLeft.y),
        tr_x: Double(topRight.x),
        tr_y: Double(topRight.y),
        bl_x: Double(bottomLeft.x),
        bl_y: Double(bottomLeft.y),
        br_x: Double(bottomRight.x),
        br_y: Double(bottomRight.y)
    )
}

/// Rebuilds a `VNRectangleObservation` from the corner points stored in a raw record.
///
/// Only the four corners of `raw` are used; its bounding box and confidence fields are
/// not read here. Two Vision initializers exist with different corner label orders, so
/// the one matching the running OS is selected.
///
/// - Parameter raw: The raw record whose `tl_*`, `tr_*`, `br_*` and `bl_*` fields supply the corners.
/// - Returns: A new observation created with `VNRequestRevisionUnspecified`.
internal func mkRectangleObservation(_ raw: VNRectangleObservationRaw) -> VNRectangleObservation {
    let topLeft = CGPoint(x: raw.tl_x, y: raw.tl_y)
    let topRight = CGPoint(x: raw.tr_x, y: raw.tr_y)
    let bottomRight = CGPoint(x: raw.br_x, y: raw.br_y)
    let bottomLeft = CGPoint(x: raw.bl_x, y: raw.bl_y)

    // Before macOS 14 only the (topLeft, bottomLeft, bottomRight, topRight) initializer is used.
    guard #available(macOS 14.0, *) else {
        return VNRectangleObservation(
            requestRevision: VNRequestRevisionUnspecified,
            topLeft: topLeft,
            bottomLeft: bottomLeft,
            bottomRight: bottomRight,
            topRight: topRight
        )
    }

    return VNRectangleObservation(
        requestRevision: VNRequestRevisionUnspecified,
        topLeft: topLeft,
        topRight: topRight,
        bottomRight: bottomRight,
        bottomLeft: bottomLeft
    )
}

/// Copies the raw bytes of a pixel buffer into newly allocated memory.
///
/// The copy spans `height * bytesPerRow` bytes, i.e. it includes any per-row padding.
/// The returned `bytes` pointer comes from `UnsafeMutableRawPointer.allocate` and is
/// never freed in this function.
/// NOTE(review): presumably ownership passes to the caller, which must deallocate it —
/// confirm against the matching release path.
///
/// - Parameter buffer: The pixel buffer to copy from.
/// - Returns: A mask record with the buffer's width, height and bytes-per-row. `bytes` is
///   `nil` when the buffer cannot be locked for reading or exposes no base address.
internal func copyTrackedPixelBuffer(_ buffer: CVPixelBuffer) -> VNSegmentationMaskRaw {
    let width = CVPixelBufferGetWidth(buffer)
    let height = CVPixelBufferGetHeight(buffer)
    let bytesPerRow = CVPixelBufferGetBytesPerRow(buffer)

    // Only touch the pixel data if the lock actually succeeded; on failure we must not
    // read the base address, and we must not issue an unbalanced unlock either.
    guard CVPixelBufferLockBaseAddress(buffer, .readOnly) == kCVReturnSuccess else {
        return VNSegmentationMaskRaw(width: width, height: height, bytes_per_row: bytesPerRow, bytes: nil)
    }
    defer { CVPixelBufferUnlockBaseAddress(buffer, .readOnly) }

    guard let base = CVPixelBufferGetBaseAddress(buffer) else {
        return VNSegmentationMaskRaw(width: width, height: height, bytes_per_row: bytesPerRow, bytes: nil)
    }

    let size = height * bytesPerRow
    let out = UnsafeMutableRawPointer.allocate(byteCount: size, alignment: 8)
    memcpy(out, base, size)
    return VNSegmentationMaskRaw(width: width, height: height, bytes_per_row: bytesPerRow, bytes: out)
}

/// Heuristically decides whether an error complains about a missing frame timestamp.
///
/// The check is purely textual, on the lowercased `localizedDescription`:
/// - any occurrence of "timestamp" counts (this also covers "presentationTimeStamp");
/// - "pts" counts only as a standalone token, so ordinary words that merely contain
///   those letters ("attempts", "accepts", "interrupts", ...) do not match.
///
/// - Parameter error: The error to inspect.
/// - Returns: `true` when the description looks like a timestamp requirement failure.
internal func isTimestampRequirementError(_ error: Error) -> Bool {
    let msg = error.localizedDescription.lowercased()
    if msg.contains("timestamp") {
        return true
    }
    // Split on every non-alphanumeric character so "pts" is matched as a whole word only.
    let tokens = msg.components(separatedBy: CharacterSet.alphanumerics.inverted)
    return tokens.contains("pts")
}

/// Holder for the identity homographic alignment constant.
@frozen
public struct VNIdentityHomographicAlignmentRaw {
    /// A `VNHomographicAlignmentRaw` with ones on the diagonal (`m00`, `m11`, `m22`),
    /// zeros in every other matrix entry, and a zero `_pad`.
    public static let value: VNHomographicAlignmentRaw = .init(
        m00: 1,
        m01: 0,
        m02: 0,
        m10: 0,
        m11: 1,
        m12: 0,
        m20: 0,
        m21: 0,
        m22: 1,
        _pad: 0
    )
}

/// Stateful object tracker that keeps one `VNTrackObjectRequest` alive across frames.
///
/// The session owns a `VNSequenceRequestHandler` and feeds each frame's resulting
/// observation back in as the request's input for the next frame.
final class ObjectTrackerSession {
    private let handler = VNSequenceRequestHandler()
    private let request: VNTrackObjectRequest
    private var lastObservation: VNDetectedObjectObservation

    /// Seeds the tracker with a bounding box and runs the request once on the first image.
    ///
    /// - Parameters:
    ///   - initialImage: The first frame.
    ///   - initialBoundingBox: Rectangle (`x`, `y`, `w`, `h`) used to seed the tracker.
    /// - Throws: Whatever the sequence handler throws while performing the request.
    init(initialImage: CGImage, initialBoundingBox: VNSimpleRectRaw) throws {
        let seed = VNDetectedObjectObservation(
            boundingBox: CGRect(
                x: initialBoundingBox.x,
                y: initialBoundingBox.y,
                width: initialBoundingBox.w,
                height: initialBoundingBox.h
            )
        )
        let trackRequest = VNTrackObjectRequest(detectedObjectObservation: seed)
        trackRequest.trackingLevel = .accurate
        request = trackRequest
        lastObservation = seed
        try advance(to: initialImage)
    }

    /// Runs the tracker on the next frame.
    ///
    /// - Parameter nextImage: The frame to track into.
    /// - Returns: The bounding box and confidence of the most recent observation. If the
    ///   request produced no usable result for this frame, the previous observation is reported.
    /// - Throws: Whatever the sequence handler throws while performing the request.
    func track(nextImage: CGImage) throws -> VNSimpleRectRaw {
        try advance(to: nextImage)
        let latest = lastObservation
        return mkRect(latest.boundingBox, latest.confidence)
    }

    /// Performs the request on `image` and, when a result is available, adopts it as both
    /// the last observation and the request's next input.
    private func advance(to image: CGImage) throws {
        try handler.perform([request], on: image)
        guard let tracked = request.results?.first as? VNDetectedObjectObservation else {
            return
        }
        lastObservation = tracked
        request.inputObservation = tracked
    }
}

/// Stateful rectangle tracker that keeps one `VNTrackRectangleRequest` alive across frames.
///
/// The session owns a `VNSequenceRequestHandler` and feeds each frame's resulting
/// observation back in as the request's input for the next frame.
final class RectangleTrackerSession {
    private let handler = VNSequenceRequestHandler()
    private let request: VNTrackRectangleRequest
    private var lastObservation: VNRectangleObservation

    /// Seeds the tracker with a rectangle and runs the request once on the first image.
    ///
    /// - Parameters:
    ///   - initialImage: The first frame.
    ///   - initialObservation: Raw rectangle converted via `mkRectangleObservation` to seed the tracker.
    /// - Throws: Whatever the sequence handler throws while performing the request.
    init(initialImage: CGImage, initialObservation: VNRectangleObservationRaw) throws {
        let seed = mkRectangleObservation(initialObservation)
        let trackRequest = VNTrackRectangleRequest(rectangleObservation: seed)
        trackRequest.trackingLevel = .accurate
        request = trackRequest
        lastObservation = seed
        try advance(to: initialImage)
    }

    /// Runs the tracker on the next frame.
    ///
    /// - Parameter nextImage: The frame to track into.
    /// - Returns: The most recent observation in raw form. If the request produced no
    ///   usable result for this frame, the previous observation is reported.
    /// - Throws: Whatever the sequence handler throws while performing the request.
    func track(nextImage: CGImage) throws -> VNRectangleObservationRaw {
        try advance(to: nextImage)
        return mkRectangleRaw(lastObservation)
    }

    /// Performs the request on `image` and, when a result is available, adopts it as both
    /// the last observation and the request's next input.
    private func advance(to image: CGImage) throws {
        try handler.perform([request], on: image)
        guard let tracked = request.results?.first as? VNRectangleObservation else {
            return
        }
        lastObservation = tracked
        request.inputObservation = tracked
    }
}

/// Stateful optical-flow tracker built on a single `VNTrackOpticalFlowRequest`.
///
/// Frames are first submitted through a `VNSequenceRequestHandler`. If that fails with an
/// error that `isTimestampRequirementError` recognises, the session switches for the rest
/// of its lifetime to a fresh `VNImageRequestHandler` per frame.
@available(macOS 14.0, *)
final class OpticalFlowTrackerSession {
    private let sequenceHandler = VNSequenceRequestHandler()
    private let request = VNTrackOpticalFlowRequest()
    private var usesImageHandlers = false

    /// Configures the request (medium accuracy, `kCVPixelFormatType_TwoComponent32Float`
    /// output) and performs it once on the reference image.
    ///
    /// - Parameter referenceImage: The first frame.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    init(referenceImage: CGImage) throws {
        request.computationAccuracy = .medium
        request.outputPixelFormat = kCVPixelFormatType_TwoComponent32Float
        try perform(on: referenceImage)
    }

    /// Performs the request through a one-off image request handler.
    private func performWithImageHandler(on image: CGImage) throws {
        let imageHandler = VNImageRequestHandler(cgImage: image, options: [:])
        try imageHandler.perform([request])
    }

    /// Performs the request on `image`, switching permanently to image handlers when the
    /// sequence handler reports a timestamp-requirement error.
    private func perform(on image: CGImage) throws {
        guard !usesImageHandlers else {
            try performWithImageHandler(on: image)
            return
        }
        do {
            try sequenceHandler.perform([request], on: image)
        } catch {
            guard isTimestampRequirementError(error) else {
                throw error
            }
            // Set the flag before retrying, so it stays set even if the retry throws.
            usesImageHandlers = true
            try performWithImageHandler(on: image)
        }
    }

    /// Runs the request on the next frame and copies out the resulting pixel buffer.
    ///
    /// - Parameter nextImage: The frame to process.
    /// - Returns: A copy of the first result's pixel buffer, or an all-zero record with
    ///   `nil` bytes when the request produced no result.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    func track(nextImage: CGImage) throws -> VNSegmentationMaskRaw {
        try perform(on: nextImage)
        if let flow = request.results?.first {
            return copyTrackedPixelBuffer(flow.pixelBuffer)
        }
        return VNSegmentationMaskRaw(width: 0, height: 0, bytes_per_row: 0, bytes: nil)
    }
}

/// Stateful translational image-registration tracker built on a single
/// `VNTrackTranslationalImageRegistrationRequest`.
///
/// Frames are first submitted through a `VNSequenceRequestHandler`. If that fails with an
/// error that `isTimestampRequirementError` recognises, the session switches for the rest
/// of its lifetime to a fresh `VNImageRequestHandler` per frame.
@available(macOS 14.0, *)
final class TranslationalImageTrackerSession {
    private let sequenceHandler = VNSequenceRequestHandler()
    private let request = VNTrackTranslationalImageRegistrationRequest()
    private var usesImageHandlers = false

    /// Performs the request once on the reference image.
    ///
    /// - Parameter referenceImage: The first frame.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    init(referenceImage: CGImage) throws {
        try perform(on: referenceImage)
    }

    /// Performs the request through a one-off image request handler.
    private func performWithImageHandler(on image: CGImage) throws {
        let imageHandler = VNImageRequestHandler(cgImage: image, options: [:])
        try imageHandler.perform([request])
    }

    /// Performs the request on `image`, switching permanently to image handlers when the
    /// sequence handler reports a timestamp-requirement error.
    private func perform(on image: CGImage) throws {
        guard !usesImageHandlers else {
            try performWithImageHandler(on: image)
            return
        }
        do {
            try sequenceHandler.perform([request], on: image)
        } catch {
            guard isTimestampRequirementError(error) else {
                throw error
            }
            // Set the flag before retrying, so it stays set even if the retry throws.
            usesImageHandlers = true
            try performWithImageHandler(on: image)
        }
    }

    /// Runs the request on the next frame and reports the translation it found.
    ///
    /// - Parameter nextImage: The frame to process.
    /// - Returns: The `tx`/`ty` components of the first result's alignment transform, or
    ///   `(0, 0)` when the request produced no result.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    func track(nextImage: CGImage) throws -> VNTranslationalAlignmentRaw {
        try perform(on: nextImage)
        if let alignment = request.results?.first {
            let transform = alignment.alignmentTransform
            return VNTranslationalAlignmentRaw(tx: Double(transform.tx), ty: Double(transform.ty))
        }
        return VNTranslationalAlignmentRaw(tx: 0, ty: 0)
    }
}

/// Stateful homographic image-registration tracker built on a single
/// `VNTrackHomographicImageRegistrationRequest`.
///
/// Frames are first submitted through a `VNSequenceRequestHandler`. If that fails with an
/// error that `isTimestampRequirementError` recognises, the session switches for the rest
/// of its lifetime to a fresh `VNImageRequestHandler` per frame.
@available(macOS 14.0, *)
final class HomographicImageTrackerSession {
    private let sequenceHandler = VNSequenceRequestHandler()
    private let request = VNTrackHomographicImageRegistrationRequest()
    private var usesImageHandlers = false

    /// Performs the request once on the reference image.
    ///
    /// - Parameter referenceImage: The first frame.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    init(referenceImage: CGImage) throws {
        try perform(on: referenceImage)
    }

    /// Performs the request through a one-off image request handler.
    private func performWithImageHandler(on image: CGImage) throws {
        let imageHandler = VNImageRequestHandler(cgImage: image, options: [:])
        try imageHandler.perform([request])
    }

    /// Performs the request on `image`, switching permanently to image handlers when the
    /// sequence handler reports a timestamp-requirement error.
    private func perform(on image: CGImage) throws {
        guard !usesImageHandlers else {
            try performWithImageHandler(on: image)
            return
        }
        do {
            try sequenceHandler.perform([request], on: image)
        } catch {
            guard isTimestampRequirementError(error) else {
                throw error
            }
            // Set the flag before retrying, so it stays set even if the retry throws.
            usesImageHandlers = true
            try performWithImageHandler(on: image)
        }
    }

    /// Runs the request on the next frame and reports the warp matrix it found.
    ///
    /// Field `mCR` of the result is taken from column `C`, component `x`/`y`/`z` for
    /// `R` = 0/1/2 of the observation's `warpTransform`.
    ///
    /// - Parameter nextImage: The frame to process.
    /// - Returns: The first result's warp transform, or
    ///   `VNIdentityHomographicAlignmentRaw.value` when the request produced no result.
    /// - Throws: Any error from performing the request that is not handled by the fallback.
    func track(nextImage: CGImage) throws -> VNHomographicAlignmentRaw {
        try perform(on: nextImage)
        if let alignment = request.results?.first {
            let (c0, c1, c2) = alignment.warpTransform.columns
            return VNHomographicAlignmentRaw(
                m00: c0.x, m01: c0.y, m02: c0.z,
                m10: c1.x, m11: c1.y, m12: c1.z,
                m20: c2.x, m21: c2.y, m22: c2.z,
                _pad: 0
            )
        }
        return VNIdentityHomographicAlignmentRaw.value
    }
}

/// C entry point: creates an `ObjectTrackerSession` from an image file and a seed box.
///
/// - Parameters:
///   - initialPath: NUL-terminated path of the first frame, loaded via `loadCGImage`.
///   - initialBoundingBoxRaw: Pointer read as a `VNSimpleRectRaw`; a null pointer is rejected.
///   - outHandle: Set to `nil` on entry; on success receives a retained opaque session
///     pointer, to be balanced by `vn_object_tracker_release`.
///   - outErrorMessage: When non-null, receives an `ffiString` message on failure.
///     NOTE(review): presumably the caller frees that string — confirm against `ffiString`.
/// - Returns: `VN_OK` on success, `VN_INVALID_ARGUMENT` for a null bounding box,
///   `VN_IMAGE_LOAD_FAILED` when the image cannot be loaded, or `VN_REQUEST_FAILED` when
///   session creation throws.
@_cdecl("vn_object_tracker_create")
public func vn_object_tracker_create(
    _ initialPath: UnsafePointer<CChar>,
    _ initialBoundingBoxRaw: UnsafeMutableRawPointer?,
    _ outHandle: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
    _ outErrorMessage: UnsafeMutablePointer<UnsafeMutablePointer<CChar>?>?
) -> Int32 {
    outHandle.pointee = nil

    guard let boxPointer = initialBoundingBoxRaw?.assumingMemoryBound(to: VNSimpleRectRaw.self) else {
        outErrorMessage?.pointee = ffiString("missing initial bounding box")
        return VN_INVALID_ARGUMENT
    }

    let path = String(cString: initialPath)
    guard let image = loadCGImage(path: path) else {
        outErrorMessage?.pointee = ffiString("could not load image at \(path)")
        return VN_IMAGE_LOAD_FAILED
    }

    let session: ObjectTrackerSession
    do {
        session = try ObjectTrackerSession(initialImage: image, initialBoundingBox: boxPointer.pointee)
    } catch {
        outErrorMessage?.pointee = ffiString("object tracker create failed: \(error.localizedDescription)")
        return VN_REQUEST_FAILED
    }

    // Hand a +1 retained reference across the FFI boundary.
    outHandle.pointee = Unmanaged.passRetained(session).toOpaque()
    return VN_OK
}

/// C entry point: advances an object tracker session by one frame.
///
/// - Parameters:
///   - handle: Opaque session pointer, interpreted as an `ObjectTrackerSession`; a null
///     pointer is rejected.
///   - nextPath: NUL-terminated path of the next frame, loaded via `loadCGImage`.
///   - outBoundingBoxRaw: Pointer written as a `VNSimpleRectRaw`. It is zeroed on entry and
///     overwritten with the tracked rectangle on success.
///   - outErrorMessage: When non-null, receives an `ffiString` message on failure.
///     NOTE(review): presumably the caller frees that string — confirm against `ffiString`.
/// - Returns: `VN_OK` on success, `VN_INVALID_ARGUMENT` for a null handle,
///   `VN_IMAGE_LOAD_FAILED` when the image cannot be loaded, or `VN_REQUEST_FAILED` when
///   tracking throws.
@_cdecl("vn_object_tracker_track")
public func vn_object_tracker_track(
    _ handle: UnsafeMutableRawPointer?,
    _ nextPath: UnsafePointer<CChar>,
    _ outBoundingBoxRaw: UnsafeMutableRawPointer,
    _ outErrorMessage: UnsafeMutablePointer<UnsafeMutablePointer<CChar>?>?
) -> Int32 {
    // Always leave the output in a defined (all-zero) state, even on early failure.
    let result = outBoundingBoxRaw.assumingMemoryBound(to: VNSimpleRectRaw.self)
    result.pointee = VNSimpleRectRaw(x: 0, y: 0, w: 0, h: 0, confidence: 0, _pad: 0)

    guard let handle else {
        outErrorMessage?.pointee = ffiString("null object tracker handle")
        return VN_INVALID_ARGUMENT
    }

    let path = String(cString: nextPath)
    guard let image = loadCGImage(path: path) else {
        outErrorMessage?.pointee = ffiString("could not load image at \(path)")
        return VN_IMAGE_LOAD_FAILED
    }

    // Borrow the session without changing its retain count.
    let session = Unmanaged<ObjectTrackerSession>.fromOpaque(handle).takeUnretainedValue()
    let tracked: VNSimpleRectRaw
    do {
        tracked = try session.track(nextImage: image)
    } catch {
        outErrorMessage?.pointee = ffiString("object tracker track failed: \(error.localizedDescription)")
        return VN_REQUEST_FAILED
    }

    result.pointee = tracked
    return VN_OK
}

/// C entry point: releases the retained reference held by an object tracker handle.
///
/// - Parameter handle: Opaque session pointer, interpreted as an `ObjectTrackerSession`.
///   A null pointer is ignored.
@_cdecl("vn_object_tracker_release")
public func vn_object_tracker_release(_ handle: UnsafeMutableRawPointer?) {
    if let handle {
        Unmanaged<ObjectTrackerSession>.fromOpaque(handle).release()
    }
}

/// C entry point: creates a `RectangleTrackerSession` from an image file and a seed rectangle.
///
/// - Parameters:
///   - initialPath: NUL-terminated path of the first frame, loaded via `loadCGImage`.
///   - initialObservationRaw: Pointer read as a `VNRectangleObservationRaw`; a null pointer
///     is rejected.
///   - outHandle: Set to `nil` on entry; on success receives a retained opaque session
///     pointer, to be balanced by `vn_rectangle_tracker_release`.
///   - outErrorMessage: When non-null, receives an `ffiString` message on failure.
///     NOTE(review): presumably the caller frees that string — confirm against `ffiString`.
/// - Returns: `VN_OK` on success, `VN_INVALID_ARGUMENT` for a null observation,
///   `VN_IMAGE_LOAD_FAILED` when the image cannot be loaded, or `VN_REQUEST_FAILED` when
///   session creation throws.
@_cdecl("vn_rectangle_tracker_create")
public func vn_rectangle_tracker_create(
    _ initialPath: UnsafePointer<CChar>,
    _ initialObservationRaw: UnsafeMutableRawPointer?,
    _ outHandle: UnsafeMutablePointer<UnsafeMutableRawPointer?>,
    _ outErrorMessage: UnsafeMutablePointer<UnsafeMutablePointer<CChar>?>?
) -> Int32 {
    outHandle.pointee = nil

    guard let seedPointer = initialObservationRaw?.assumingMemoryBound(to: VNRectangleObservationRaw.self) else {
        outErrorMessage?.pointee = ffiString("missing initial rectangle observation")
        return VN_INVALID_ARGUMENT
    }

    let path = String(cString: initialPath)
    guard let image = loadCGImage(path: path) else {
        outErrorMessage?.pointee = ffiString("could not load image at \(path)")
        return VN_IMAGE_LOAD_FAILED
    }

    let session: RectangleTrackerSession
    do {
        session = try RectangleTrackerSession(initialImage: image, initialObservation: seedPointer.pointee)
    } catch {
        outErrorMessage?.pointee = ffiString("rectangle tracker create failed: \(error.localizedDescription)")
        return VN_REQUEST_FAILED
    }

    // Hand a +1 retained reference across the FFI boundary.
    outHandle.pointee = Unmanaged.passRetained(session).toOpaque()
    return VN_OK
}

/// C entry point: advances a rectangle tracker session by one frame.
///
/// - Parameters:
///   - handle: Opaque session pointer, interpreted as a `RectangleTrackerSession`; a null
///     pointer is rejected.
///   - nextPath: NUL-terminated path of the next frame, loaded via `loadCGImage`.
///   - outObservationRaw: Pointer written as a `VNRectangleObservationRaw`. It is zeroed on
///     entry and overwritten with the tracked rectangle on success.
///   - outErrorMessage: When non-null, receives an `ffiString` message on failure.
///     NOTE(review): presumably the caller frees that string — confirm against `ffiString`.
/// - Returns: `VN_OK` on success, `VN_INVALID_ARGUMENT` for a null handle,
///   `VN_IMAGE_LOAD_FAILED` when the image cannot be loaded, or `VN_REQUEST_FAILED` when
///   tracking throws.
@_cdecl("vn_rectangle_tracker_track")
public func vn_rectangle_tracker_track(
    _ handle: UnsafeMutableRawPointer?,
    _ nextPath: UnsafePointer<CChar>,
    _ outObservationRaw: UnsafeMutableRawPointer,
    _ outErrorMessage: UnsafeMutablePointer<UnsafeMutablePointer<CChar>?>?
) -> Int32 {
    // Always leave the output in a defined (all-zero) state, even on early failure.
    let result = outObservationRaw.assumingMemoryBound(to: VNRectangleObservationRaw.self)
    result.pointee = VNRectangleObservationRaw(
        bbox_x: 0,
        bbox_y: 0,
        bbox_w: 0,
        bbox_h: 0,
        confidence: 0,
        tl_x: 0,
        tl_y: 0,
        tr_x: 0,
        tr_y: 0,
        bl_x: 0,
        bl_y: 0,
        br_x: 0,
        br_y: 0
    )

    guard let handle else {
        outErrorMessage?.pointee = ffiString("null rectangle tracker handle")
        return VN_INVALID_ARGUMENT
    }

    let path = String(cString: nextPath)
    guard let image = loadCGImage(path: path) else {
        outErrorMessage?.pointee = ffiString("could not load image at \(path)")
        return VN_IMAGE_LOAD_FAILED
    }

    // Borrow the session without changing its retain count.
    let session = Unmanaged<RectangleTrackerSession>.fromOpaque(handle).takeUnretainedValue()
    let tracked: VNRectangleObservationRaw
    do {
        tracked = try session.track(nextImage: image)
    } catch {
        outErrorMessage?.pointee = ffiString("rectangle tracker track failed: \(error.localizedDescription)")
        return VN_REQUEST_FAILED
    }

    result.pointee = tracked
    return VN_OK
}

/// C entry point: releases the retained reference held by a rectangle tracker handle.
///
/// - Parameter handle: Opaque session pointer, interpreted as a `RectangleTrackerSession`.
///   A null pointer is ignored.
@_cdecl("vn_rectangle_tracker_release")
public func vn_rectangle_tracker_release(_ handle: UnsafeMutableRawPointer?) {
    if let handle {
        Unmanaged<RectangleTrackerSession>.fromOpaque(handle).release()
    }
}