Skip to main content

foundation_models/
async_api.rs

1//! Executor-agnostic async API for `FoundationModels` (Tier 1).
2//!
3//! Enabled with the `async` Cargo feature.  Works with any async runtime
4//! (Tokio, async-std, smol, pollster, …) because it uses only `std` types
5//! internally.
6//!
7//! ## Wrapped Apple APIs
8//!
9//! | Rust type | Apple API | Notes |
10//! |-----------|-----------|-------|
11//! | [`AsyncSession::respond`] | `LanguageModelSession.respond(to:)` | Returns `SessionResponse<String>` |
12//! | [`AsyncSession::respond_generating`] | `LanguageModelSession.respond(to:generating:)` | Returns `SessionResponse<GeneratedContent>` |
13//! | [`AsyncAdapter::from_name`] | `SystemLanguageModel.Adapter init(name:)` | Returns `Adapter` |
14//! | [`AsyncAdapter::compatibility`] | `SystemLanguageModel.Adapter.compatibility(for:)` | Returns `Vec<String>` |
15//! | [`AsyncAdapter::compile`] | `SystemLanguageModel.Adapter.compile()` | Returns `()` |
16//!
17//! ## Tier 2 note
18//!
19//! `LanguageModelSession.streamResponse(to:)` is an `AsyncSequence` — a
20//! multi-fire stream, not a one-shot future.  It is deferred to **Tier 2**
21//! (stream pattern).  Use [`crate::LanguageModelSession::stream`] for
22//! synchronous streaming in the meantime.
23//!
24//! ## Example
25//!
26//! ```rust,no_run
27//! use foundation_models::{LanguageModelSession, SystemLanguageModel};
28//! use foundation_models::async_api::AsyncSession;
29//!
30//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
31//! if !SystemLanguageModel::is_available() {
32//!     eprintln!("SKIP: FoundationModels unavailable");
33//!     return Ok(());
34//! }
35//! pollster::block_on(async {
36//!     let session = LanguageModelSession::new();
37//!     let async_session = AsyncSession::new(&session);
38//!     let reply = async_session.respond("Name three Norse gods.")?.await?;
39//!     println!("{}", reply.content);
40//!     Ok::<(), Box<dyn std::error::Error>>(())
41//! })
42//! # }
43//! ```
44
45use std::ffi::{c_void, CStr, CString};
46use std::future::Future;
47use std::pin::Pin;
48use std::task::{Context, Poll};
49
50use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
51use serde::Deserialize;
52
53use crate::content::{BridgeGeneratedContent, GeneratedContent};
54use crate::error::FMError;
55use crate::ffi;
56use crate::generation::GenerationOptions;
57use crate::model::Adapter;
58use crate::prompt::{Prompt, ToPrompt};
59use crate::schema::GenerationSchema;
60use crate::session::{decode_bridge_text_response, SessionResponse};
61use crate::transcript::Transcript;
62
63// ============================================================================
64// Private bridge structs – mirror the JSON shapes emitted by SessionExtras.swift
65// ============================================================================
66
67#[derive(Debug, Deserialize)]
68struct AsyncBridgeStructuredResponse {
69    content: BridgeGeneratedContent,
70    #[serde(rename = "rawContent")]
71    raw_content: BridgeGeneratedContent,
72    #[serde(rename = "transcriptJSON")]
73    transcript_json: String,
74}
75
76// ============================================================================
77// Opaque pointer newtype – needed so AsyncCompletion<OpaquePtr> is Send
78// ============================================================================
79
/// Newtype over the raw opaque handle delivered by the adapter-init
/// callback, so that it can travel through `AsyncCompletion<OpaquePtr>`.
///
/// # Safety
///
/// The wrapped pointer is a retained `AdapterBox` that the Swift side
/// produced with `Unmanaged.passRetained(…).toOpaque()`.  Rust never
/// dereferences it: the handle is only stored and eventually handed back to
/// `fm_object_release`.  Because Swift reference counting is thread-safe,
/// moving the handle to another thread is sound, which is what the `Send`
/// impl below asserts.
struct OpaquePtr(*mut c_void);
// SAFETY: the pointer is an opaque, retained Swift object that is never
// dereferenced from Rust; see the type-level documentation.
unsafe impl Send for OpaquePtr {}
91
92// ============================================================================
93// Callback: `FmRespondCallback` (4-arg) → AsyncCompletion<String>
94//
95// Reuses the existing `fm_session_respond_request_json` FFI which already
96// runs `try await session.respond(…)` inside a Swift Task.
97// ============================================================================
98
99/// Async respond callback.  Matches `ffi::FmRespondCallback`.
100///
101/// On success copies the JSON response to an owned `String` and completes
102/// the `AsyncCompletion`.  On failure maps the status + error to an
103/// `FMError` message string (the Future newtypes re-map that to `FMError`).
104///
105/// # Safety
106///
107/// `ctx` must be a valid `AsyncCompletion<String>` context pointer.
108/// `response` and `error` are nullable C strings owned by the Swift bridge.
109unsafe extern "C" fn respond_async_cb(
110    ctx: *mut c_void,
111    response: *mut std::ffi::c_char,
112    error: *mut std::ffi::c_char,
113    status: i32,
114) {
115    if status == ffi::status::OK && !response.is_null() {
116        let s = unsafe { CStr::from_ptr(response) }
117            .to_string_lossy()
118            .into_owned();
119        unsafe { ffi::fm_string_free(response) };
120        unsafe { AsyncCompletion::complete_ok(ctx, s) };
121    } else {
122        // Re-use the existing from_swift error mapper; convert FMError to String
123        // so we can store it in AsyncCompletion<String>.
124        let fm_err = crate::error::from_swift(status, error);
125        unsafe { AsyncCompletion::<String>::complete_err(ctx, fm_err.to_string()) };
126    }
127}
128
129// ============================================================================
130// Callback: 3-arg async callback → AsyncCompletion<OpaquePtr>
131//
132// Used by fm_adapter_create_from_name_async.
133// ============================================================================
134
135/// # Safety
136///
137/// `ctx` must be a valid `AsyncCompletion<OpaquePtr>` context pointer.
138unsafe extern "C" fn adapter_init_async_cb(
139    result: *mut c_void,
140    error: *const std::ffi::c_char,
141    ctx: *mut c_void,
142) {
143    if !error.is_null() {
144        let msg = unsafe { error_from_cstr(error) };
145        unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, msg) };
146    } else if !result.is_null() {
147        unsafe { AsyncCompletion::complete_ok(ctx, OpaquePtr(result)) };
148    } else {
149        unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, "null adapter pointer".into()) };
150    }
151}
152
153// ============================================================================
154// Callback: 3-arg async callback → AsyncCompletion<String>
155//
156// Used by fm_adapter_compatibility_async.  The result pointer is a strdup'd
157// JSON string; we copy it and free it.
158// ============================================================================
159
160/// # Safety
161///
162/// `ctx` must be a valid `AsyncCompletion<String>` context pointer.
163/// `result` (when non-null) must be a heap-allocated C string freed with
164/// `fm_string_free`.
165unsafe extern "C" fn adapter_compat_async_cb(
166    result: *mut c_void,
167    error: *const std::ffi::c_char,
168    ctx: *mut c_void,
169) {
170    if !error.is_null() {
171        let msg = unsafe { error_from_cstr(error) };
172        unsafe { AsyncCompletion::<String>::complete_err(ctx, msg) };
173    } else if !result.is_null() {
174        let s = unsafe { CStr::from_ptr(result.cast::<std::ffi::c_char>()) }
175            .to_string_lossy()
176            .into_owned();
177        // Free the strdup'd JSON string allocated by the Swift bridge.
178        unsafe { ffi::fm_string_free(result.cast::<std::ffi::c_char>()) };
179        unsafe { AsyncCompletion::complete_ok(ctx, s) };
180    } else {
181        unsafe { AsyncCompletion::<String>::complete_err(ctx, "null compatibility result".into()) };
182    }
183}
184
185/// # Safety
186///
187/// `ctx` must be a valid `AsyncCompletion<()>` context pointer.
188unsafe extern "C" fn adapter_compile_async_cb(
189    context: *mut c_void,
190    response: *mut std::ffi::c_char,
191    error: *mut std::ffi::c_char,
192    status: i32,
193) {
194    if !response.is_null() {
195        unsafe { ffi::fm_string_free(response) };
196    }
197
198    if status == ffi::status::OK {
199        if !error.is_null() {
200            unsafe { ffi::fm_string_free(error) };
201        }
202        unsafe { AsyncCompletion::complete_ok(context, ()) };
203        return;
204    }
205
206    let message = if error.is_null() {
207        crate::error::from_swift(status, std::ptr::null_mut()).to_string()
208    } else {
209        let message = unsafe { CStr::from_ptr(error) }.to_string_lossy().into_owned();
210        unsafe { ffi::fm_string_free(error) };
211        message
212    };
213    unsafe { AsyncCompletion::<()>::complete_err(context, message) };
214}
215
216// ============================================================================
217// RespondFuture — LanguageModelSession.respond(to:)
218// ============================================================================
219
220/// Future returned by [`AsyncSession::respond`].
221///
222/// Resolves to `Result<SessionResponse<String>, FMError>`.
223pub struct RespondFuture {
224    inner: AsyncCompletionFuture<String>,
225}
226
227impl std::fmt::Debug for RespondFuture {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        f.debug_struct("RespondFuture").finish_non_exhaustive()
230    }
231}
232
233impl Future for RespondFuture {
234    type Output = Result<SessionResponse<String>, FMError>;
235
236    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        Pin::new(&mut self.inner).poll(cx).map(|r| {
238            r.map_err(|msg| FMError::Unknown {
239                code: ffi::status::UNKNOWN,
240                message: msg,
241            })
242            .and_then(|json| decode_bridge_text_response(&json))
243        })
244    }
245}
246
247// ============================================================================
248// RespondGeneratingFuture — LanguageModelSession.respond(to:generating:)
249// ============================================================================
250
251/// Future returned by [`AsyncSession::respond_generating`].
252///
253/// Resolves to `Result<SessionResponse<GeneratedContent>, FMError>`.
254pub struct RespondGeneratingFuture {
255    inner: AsyncCompletionFuture<String>,
256}
257
258impl std::fmt::Debug for RespondGeneratingFuture {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.debug_struct("RespondGeneratingFuture")
261            .finish_non_exhaustive()
262    }
263}
264
265impl Future for RespondGeneratingFuture {
266    type Output = Result<SessionResponse<GeneratedContent>, FMError>;
267
268    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
269        Pin::new(&mut self.inner).poll(cx).map(|r| {
270            r.map_err(|msg| FMError::Unknown {
271                code: ffi::status::UNKNOWN,
272                message: msg,
273            })
274            .and_then(|json| {
275                let response: AsyncBridgeStructuredResponse = serde_json::from_str(&json)
276                    .map_err(|e| FMError::DecodingFailure(e.to_string()))?;
277                Ok(SessionResponse {
278                    content: GeneratedContent::from_bridge_payload(response.content, true)?,
279                    raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
280                    transcript: Transcript::from_json_str(&response.transcript_json)?,
281                })
282            })
283        })
284    }
285}
286
287// ============================================================================
288// AdapterInitFuture — SystemLanguageModel.Adapter init(name:)
289// ============================================================================
290
291/// Future returned by [`AsyncAdapter::from_name`].
292///
293/// Resolves to `Result<Adapter, FMError>`.
294pub struct AdapterInitFuture {
295    inner: AsyncCompletionFuture<OpaquePtr>,
296}
297
298impl std::fmt::Debug for AdapterInitFuture {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        f.debug_struct("AdapterInitFuture").finish_non_exhaustive()
301    }
302}
303
304impl Future for AdapterInitFuture {
305    type Output = Result<Adapter, FMError>;
306
307    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
308        Pin::new(&mut self.inner).poll(cx).map(|r| {
309            r.map_err(FMError::AdapterInvalidName)
310                .map(|OpaquePtr(ptr)| Adapter { ptr })
311        })
312    }
313}
314
315// ============================================================================
316// AdapterCompatibilityFuture — SystemLanguageModel.Adapter.compatibility(for:)
317// ============================================================================
318
319/// Future returned by [`AsyncAdapter::compatibility`].
320///
321/// Resolves to `Result<Vec<String>, FMError>`.
322pub struct AdapterCompatibilityFuture {
323    inner: AsyncCompletionFuture<String>,
324}
325
326impl std::fmt::Debug for AdapterCompatibilityFuture {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        f.debug_struct("AdapterCompatibilityFuture")
329            .finish_non_exhaustive()
330    }
331}
332
333impl Future for AdapterCompatibilityFuture {
334    type Output = Result<Vec<String>, FMError>;
335
336    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
337        Pin::new(&mut self.inner).poll(cx).map(|r| {
338            r.map_err(FMError::AdapterCompatibleNotFound)
339                .and_then(|json| {
340                    serde_json::from_str::<Vec<String>>(&json)
341                        .map_err(|e| FMError::DecodingFailure(e.to_string()))
342                })
343        })
344    }
345}
346
347/// Future returned by [`AsyncAdapter::compile`].
348///
349/// Resolves to `Result<(), FMError>`.
350pub struct CompileAdapterFuture {
351    inner: AsyncCompletionFuture<()>,
352}
353
354impl std::fmt::Debug for CompileAdapterFuture {
355    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356        f.debug_struct("CompileAdapterFuture")
357            .finish_non_exhaustive()
358    }
359}
360
361impl Future for CompileAdapterFuture {
362    type Output = Result<(), FMError>;
363
364    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
365        Pin::new(&mut self.inner).poll(cx).map(|r| {
366            r.map_err(|message| FMError::Unknown {
367                code: ffi::status::UNKNOWN,
368                message,
369            })
370        })
371    }
372}
373
374// ============================================================================
375// AsyncSession — async wrapper around LanguageModelSession
376// ============================================================================
377
378/// Async wrapper around [`crate::LanguageModelSession`].
379///
380/// Borrows the session for its lifetime; the session itself must outlive all
381/// in-flight futures.
382///
383/// # Examples
384///
385/// ```rust,no_run
386/// use foundation_models::{LanguageModelSession, SystemLanguageModel};
387/// use foundation_models::async_api::AsyncSession;
388///
389/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
390/// if !SystemLanguageModel::is_available() { return Ok(()); }
391/// pollster::block_on(async {
392///     let session = LanguageModelSession::new();
393///     let reply = AsyncSession::new(&session).respond("Hi!")?.await?;
394///     println!("{}", reply.content);
395///     Ok::<(), Box<dyn std::error::Error>>(())
396/// })
397/// # }
398/// ```
399pub struct AsyncSession<'s> {
400    session: &'s crate::session::LanguageModelSession,
401}
402
403impl<'s> AsyncSession<'s> {
404    /// Wrap a [`crate::LanguageModelSession`] for async use.
405    #[must_use]
406    pub fn new(session: &'s crate::session::LanguageModelSession) -> Self {
407        Self { session }
408    }
409
410    /// Async version of `LanguageModelSession.respond(to:)`.
411    ///
412    /// Corresponds to the Swift `async throws` method
413    /// `LanguageModelSession.respond(to:)`.
414    ///
415    /// # Errors
416    ///
417    /// Returns an [`FMError`] if the model is unavailable or generation fails.
418    pub fn respond(&self, prompt: impl ToPrompt) -> Result<RespondFuture, FMError> {
419        let prompt = prompt.to_prompt()?;
420        let payload = build_text_request_json(&prompt, GenerationOptions::new())?;
421        let session_ptr = self.session.as_ptr();
422        let (future, ctx) = AsyncCompletion::create();
423        unsafe {
424            ffi::fm_session_respond_request_json(
425                session_ptr,
426                payload.as_ptr(),
427                ctx,
428                respond_async_cb,
429            );
430        }
431        Ok(RespondFuture { inner: future })
432    }
433
434    /// Async version of `LanguageModelSession.respond(to:)` with [`GenerationOptions`].
435    ///
436    /// # Errors
437    ///
438    /// Returns an [`FMError`] if the model is unavailable or generation fails.
439    pub fn respond_with_options(
440        &self,
441        prompt: impl ToPrompt,
442        options: GenerationOptions,
443    ) -> Result<RespondFuture, FMError> {
444        let prompt = prompt.to_prompt()?;
445        let payload = build_text_request_json(&prompt, options)?;
446        let session_ptr = self.session.as_ptr();
447        let (future, ctx) = AsyncCompletion::create();
448        unsafe {
449            ffi::fm_session_respond_request_json(
450                session_ptr,
451                payload.as_ptr(),
452                ctx,
453                respond_async_cb,
454            );
455        }
456        Ok(RespondFuture { inner: future })
457    }
458
459    /// Async version of `LanguageModelSession.respond(to:generating:)`.
460    ///
461    /// Generates a structured `GeneratedContent` response according to
462    /// `schema`.  Corresponds to the Swift `async throws` method
463    /// `LanguageModelSession.respond(to:generating:)`.
464    ///
465    /// # Errors
466    ///
467    /// Returns an [`FMError`] if the model is unavailable or generation fails.
468    pub fn respond_generating(
469        &self,
470        prompt: impl ToPrompt,
471        schema: &GenerationSchema,
472        include_schema_in_prompt: bool,
473        options: GenerationOptions,
474    ) -> Result<RespondGeneratingFuture, FMError> {
475        let prompt = prompt.to_prompt()?;
476        let payload =
477            build_structured_request_json(&prompt, options, schema, include_schema_in_prompt)?;
478        let session_ptr = self.session.as_ptr();
479        let (future, ctx) = AsyncCompletion::create();
480        unsafe {
481            ffi::fm_session_respond_request_json(
482                session_ptr,
483                payload.as_ptr(),
484                ctx,
485                respond_async_cb,
486            );
487        }
488        Ok(RespondGeneratingFuture { inner: future })
489    }
490}
491
492impl std::fmt::Debug for AsyncSession<'_> {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        f.debug_struct("AsyncSession").finish_non_exhaustive()
495    }
496}
497
498// ============================================================================
499// AsyncAdapter — async adapter lifecycle
500// ============================================================================
501
502/// Namespace for async [`Adapter`] operations.
503///
504/// # Examples
505///
506/// ```rust,no_run
507/// use foundation_models::async_api::AsyncAdapter;
508///
509/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
510/// pollster::block_on(async {
511///     let ids = AsyncAdapter::compatibility("com.example.MyAdapter")?.await?;
512///     println!("compatible: {ids:?}");
513///     Ok::<(), Box<dyn std::error::Error>>(())
514/// })
515/// # }
516/// ```
517pub struct AsyncAdapter;
518
519impl AsyncAdapter {
520    /// Async version of `SystemLanguageModel.Adapter init(name:)`.
521    ///
522    /// Loads the named adapter asynchronously, returning a ready-to-use
523    /// [`Adapter`] handle.
524    ///
525    /// # Errors
526    ///
527    /// Returns an [`FMError::AdapterInvalidName`] if the adapter is not found
528    /// or the name contains a NUL byte.
529    pub fn from_name(name: &str) -> Result<AdapterInitFuture, FMError> {
530        let cname = CString::new(name)
531            .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
532        let (future, ctx) = AsyncCompletion::create();
533        unsafe {
534            ffi::fm_adapter_create_from_name_async(cname.as_ptr(), ctx, adapter_init_async_cb);
535        }
536        Ok(AdapterInitFuture { inner: future })
537    }
538
539    /// Async version of `SystemLanguageModel.Adapter.compatibility(for:)`.
540    ///
541    /// Returns the list of compatible adapter identifiers for the given
542    /// logical adapter name.
543    ///
544    /// # Errors
545    ///
546    /// Returns an [`FMError::AdapterCompatibleNotFound`] on failure.
547    pub fn compatibility(name: &str) -> Result<AdapterCompatibilityFuture, FMError> {
548        let cname = CString::new(name)
549            .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
550        let (future, ctx) = AsyncCompletion::create();
551        unsafe {
552            ffi::fm_adapter_compatibility_async(cname.as_ptr(), ctx, adapter_compat_async_cb);
553        }
554        Ok(AdapterCompatibilityFuture { inner: future })
555    }
556
557    /// Async version of `SystemLanguageModel.Adapter.compile()`.
558    #[must_use]
559    pub fn compile(adapter: &Adapter) -> CompileAdapterFuture {
560        let (future, ctx) = AsyncCompletion::create();
561        unsafe {
562            ffi::fm_adapter_compile(adapter.ptr, ctx, adapter_compile_async_cb);
563        }
564        CompileAdapterFuture { inner: future }
565    }
566}
567
568// ============================================================================
569// Internal JSON request builders
570// ============================================================================
571
572fn build_text_request_json(
573    prompt: &Prompt,
574    options: GenerationOptions,
575) -> Result<CString, FMError> {
576    build_request_json_inner(prompt, options, None, true)
577}
578
579fn build_structured_request_json(
580    prompt: &Prompt,
581    options: GenerationOptions,
582    schema: &GenerationSchema,
583    include_schema_in_prompt: bool,
584) -> Result<CString, FMError> {
585    build_request_json_inner(prompt, options, Some(schema), include_schema_in_prompt)
586}
587
588fn build_request_json_inner(
589    prompt: &Prompt,
590    options: GenerationOptions,
591    schema: Option<&GenerationSchema>,
592    include_schema_in_prompt: bool,
593) -> Result<CString, FMError> {
594    use crate::generation::SamplingMode;
595    use serde_json::json;
596
597    let sampling = match options.sampling() {
598        SamplingMode::Default => json!({ "mode": "default" }),
599        SamplingMode::Greedy => json!({ "mode": "greedy" }),
600        SamplingMode::TopK(k) => json!({
601            "mode": "top_k",
602            "topK": k,
603            "seed": options.sampling_seed(),
604        }),
605        SamplingMode::TopP(p) => json!({
606            "mode": "top_p",
607            "topP": p,
608            "seed": options.sampling_seed(),
609        }),
610    };
611    let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
612        schema.effective_include_schema_in_prompt(include_schema_in_prompt)
613    });
614    let payload = serde_json::to_string(&json!({
615        "prompt": prompt.to_bridge_value(),
616        "options": {
617            "temperature": options.temperature(),
618            "maximumResponseTokens": options.maximum_response_tokens(),
619            "sampling": sampling,
620        },
621        "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
622        "includeSchemaInPrompt": include_schema_in_prompt,
623    }))
624    .map_err(|e| FMError::InvalidArgument(format!("request not JSON-serializable: {e}")))?;
625    CString::new(payload)
626        .map_err(|e| FMError::InvalidArgument(format!("request JSON contains NUL: {e}")))
627}