Skip to main content

foundation_models/
async_api.rs

1//! Executor-agnostic async API for `FoundationModels` (Tier 1).
2//!
3//! Enabled with the `async` Cargo feature.  Works with any async runtime
4//! (Tokio, async-std, smol, pollster, …) because it uses only `std` types
5//! internally.
6//!
7//! ## Wrapped Apple APIs
8//!
9//! | Rust type | Apple API | Notes |
10//! |-----------|-----------|-------|
11//! | [`AsyncSession::respond`] | `LanguageModelSession.respond(to:)` | Returns `SessionResponse<String>` |
12//! | [`AsyncSession::respond_generating`] | `LanguageModelSession.respond(to:generating:)` | Returns `SessionResponse<GeneratedContent>` |
13//! | [`AsyncAdapter::from_name`] | `SystemLanguageModel.Adapter init(name:)` | Returns `Adapter` |
14//! | [`AsyncAdapter::compatibility`] | `SystemLanguageModel.Adapter.compatibility(for:)` | Returns `Vec<String>` |
15//!
16//! ## Tier 2 note
17//!
18//! `LanguageModelSession.streamResponse(to:)` is an `AsyncSequence` — a
19//! multi-fire stream, not a one-shot future.  It is deferred to **Tier 2**
20//! (stream pattern).  Use [`crate::LanguageModelSession::stream`] for
21//! synchronous streaming in the meantime.
22//!
23//! ## Example
24//!
25//! ```rust,no_run
26//! use foundation_models::{LanguageModelSession, SystemLanguageModel};
27//! use foundation_models::async_api::AsyncSession;
28//!
29//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
30//! if !SystemLanguageModel::is_available() {
31//!     eprintln!("SKIP: FoundationModels unavailable");
32//!     return Ok(());
33//! }
34//! pollster::block_on(async {
35//!     let session = LanguageModelSession::new();
36//!     let async_session = AsyncSession::new(&session);
37//!     let reply = async_session.respond("Name three Norse gods.")?.await?;
38//!     println!("{}", reply.content);
39//!     Ok::<(), Box<dyn std::error::Error>>(())
40//! })
41//! # }
42//! ```
43
44use std::ffi::{c_void, CStr, CString};
45use std::future::Future;
46use std::pin::Pin;
47use std::task::{Context, Poll};
48
49use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
50use serde::Deserialize;
51
52use crate::content::{BridgeGeneratedContent, GeneratedContent};
53use crate::error::FMError;
54use crate::ffi;
55use crate::generation::GenerationOptions;
56use crate::model::Adapter;
57use crate::prompt::{Prompt, ToPrompt};
58use crate::schema::GenerationSchema;
59use crate::session::{decode_bridge_text_response, SessionResponse};
60use crate::transcript::Transcript;
61
62// ============================================================================
63// Private bridge structs – mirror the JSON shapes emitted by SessionExtras.swift
64// ============================================================================
65
66#[derive(Debug, Deserialize)]
67struct AsyncBridgeStructuredResponse {
68    content: BridgeGeneratedContent,
69    #[serde(rename = "rawContent")]
70    raw_content: BridgeGeneratedContent,
71    #[serde(rename = "transcriptJSON")]
72    transcript_json: String,
73}
74
75// ============================================================================
76// Opaque pointer newtype – needed so AsyncCompletion<OpaquePtr> is Send
77// ============================================================================
78
/// Raw pointer handed back by the Swift bridge, wrapped in a newtype so it can
/// be carried as the payload of an `AsyncCompletion` (which needs `Send`).
///
/// # Safety
///
/// The wrapped pointer is a retained `AdapterBox` created on the Swift side
/// with `Unmanaged.passRetained(…).toOpaque()`.  Rust treats it purely as an
/// opaque token: it is never dereferenced here and is only ever handed back
/// to `fm_object_release`.  Swift reference counting is thread-safe, so
/// moving the token to another thread is sound.
struct OpaquePtr(*mut c_void);

// SAFETY: the pointer is an opaque, never-dereferenced handle to a Swift
// object whose retain/release operations are thread-safe (see the type docs).
unsafe impl Send for OpaquePtr {}
90
91// ============================================================================
92// Callback: `FmRespondCallback` (4-arg) → AsyncCompletion<String>
93//
94// Reuses the existing `fm_session_respond_request_json` FFI which already
95// runs `try await session.respond(…)` inside a Swift Task.
96// ============================================================================
97
98/// Async respond callback.  Matches `ffi::FmRespondCallback`.
99///
100/// On success copies the JSON response to an owned `String` and completes
101/// the `AsyncCompletion`.  On failure maps the status + error to an
102/// `FMError` message string (the Future newtypes re-map that to `FMError`).
103///
104/// # Safety
105///
106/// `ctx` must be a valid `AsyncCompletion<String>` context pointer.
107/// `response` and `error` are nullable C strings owned by the Swift bridge.
108unsafe extern "C" fn respond_async_cb(
109    ctx: *mut c_void,
110    response: *mut std::ffi::c_char,
111    error: *mut std::ffi::c_char,
112    status: i32,
113) {
114    if status == ffi::status::OK && !response.is_null() {
115        let s = unsafe { CStr::from_ptr(response) }
116            .to_string_lossy()
117            .into_owned();
118        unsafe { ffi::fm_string_free(response) };
119        unsafe { AsyncCompletion::complete_ok(ctx, s) };
120    } else {
121        // Re-use the existing from_swift error mapper; convert FMError to String
122        // so we can store it in AsyncCompletion<String>.
123        let fm_err = crate::error::from_swift(status, error);
124        unsafe { AsyncCompletion::<String>::complete_err(ctx, fm_err.to_string()) };
125    }
126}
127
128// ============================================================================
129// Callback: 3-arg async callback → AsyncCompletion<OpaquePtr>
130//
131// Used by fm_adapter_create_from_name_async.
132// ============================================================================
133
134/// # Safety
135///
136/// `ctx` must be a valid `AsyncCompletion<OpaquePtr>` context pointer.
137unsafe extern "C" fn adapter_init_async_cb(
138    result: *mut c_void,
139    error: *const std::ffi::c_char,
140    ctx: *mut c_void,
141) {
142    if !error.is_null() {
143        let msg = unsafe { error_from_cstr(error) };
144        unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, msg) };
145    } else if !result.is_null() {
146        unsafe { AsyncCompletion::complete_ok(ctx, OpaquePtr(result)) };
147    } else {
148        unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, "null adapter pointer".into()) };
149    }
150}
151
152// ============================================================================
153// Callback: 3-arg async callback → AsyncCompletion<String>
154//
155// Used by fm_adapter_compatibility_async.  The result pointer is a strdup'd
156// JSON string; we copy it and free it.
157// ============================================================================
158
159/// # Safety
160///
161/// `ctx` must be a valid `AsyncCompletion<String>` context pointer.
162/// `result` (when non-null) must be a heap-allocated C string freed with
163/// `fm_string_free`.
164unsafe extern "C" fn adapter_compat_async_cb(
165    result: *mut c_void,
166    error: *const std::ffi::c_char,
167    ctx: *mut c_void,
168) {
169    if !error.is_null() {
170        let msg = unsafe { error_from_cstr(error) };
171        unsafe { AsyncCompletion::<String>::complete_err(ctx, msg) };
172    } else if !result.is_null() {
173        let s = unsafe { CStr::from_ptr(result.cast::<std::ffi::c_char>()) }
174            .to_string_lossy()
175            .into_owned();
176        // Free the strdup'd JSON string allocated by the Swift bridge.
177        unsafe { ffi::fm_string_free(result.cast::<std::ffi::c_char>()) };
178        unsafe { AsyncCompletion::complete_ok(ctx, s) };
179    } else {
180        unsafe { AsyncCompletion::<String>::complete_err(ctx, "null compatibility result".into()) };
181    }
182}
183
184// ============================================================================
185// RespondFuture — LanguageModelSession.respond(to:)
186// ============================================================================
187
188/// Future returned by [`AsyncSession::respond`].
189///
190/// Resolves to `Result<SessionResponse<String>, FMError>`.
191pub struct RespondFuture {
192    inner: AsyncCompletionFuture<String>,
193}
194
195impl std::fmt::Debug for RespondFuture {
196    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
197        f.debug_struct("RespondFuture").finish_non_exhaustive()
198    }
199}
200
201impl Future for RespondFuture {
202    type Output = Result<SessionResponse<String>, FMError>;
203
204    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
205        Pin::new(&mut self.inner).poll(cx).map(|r| {
206            r.map_err(|msg| FMError::Unknown {
207                code: ffi::status::UNKNOWN,
208                message: msg,
209            })
210            .and_then(|json| decode_bridge_text_response(&json))
211        })
212    }
213}
214
215// ============================================================================
216// RespondGeneratingFuture — LanguageModelSession.respond(to:generating:)
217// ============================================================================
218
219/// Future returned by [`AsyncSession::respond_generating`].
220///
221/// Resolves to `Result<SessionResponse<GeneratedContent>, FMError>`.
222pub struct RespondGeneratingFuture {
223    inner: AsyncCompletionFuture<String>,
224}
225
226impl std::fmt::Debug for RespondGeneratingFuture {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        f.debug_struct("RespondGeneratingFuture")
229            .finish_non_exhaustive()
230    }
231}
232
233impl Future for RespondGeneratingFuture {
234    type Output = Result<SessionResponse<GeneratedContent>, FMError>;
235
236    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        Pin::new(&mut self.inner).poll(cx).map(|r| {
238            r.map_err(|msg| FMError::Unknown {
239                code: ffi::status::UNKNOWN,
240                message: msg,
241            })
242            .and_then(|json| {
243                let response: AsyncBridgeStructuredResponse = serde_json::from_str(&json)
244                    .map_err(|e| FMError::DecodingFailure(e.to_string()))?;
245                Ok(SessionResponse {
246                    content: GeneratedContent::from_bridge_payload(response.content, true)?,
247                    raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
248                    transcript: Transcript::from_json_str(&response.transcript_json)?,
249                })
250            })
251        })
252    }
253}
254
255// ============================================================================
256// AdapterInitFuture — SystemLanguageModel.Adapter init(name:)
257// ============================================================================
258
259/// Future returned by [`AsyncAdapter::from_name`].
260///
261/// Resolves to `Result<Adapter, FMError>`.
262pub struct AdapterInitFuture {
263    inner: AsyncCompletionFuture<OpaquePtr>,
264}
265
266impl std::fmt::Debug for AdapterInitFuture {
267    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
268        f.debug_struct("AdapterInitFuture").finish_non_exhaustive()
269    }
270}
271
272impl Future for AdapterInitFuture {
273    type Output = Result<Adapter, FMError>;
274
275    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
276        Pin::new(&mut self.inner).poll(cx).map(|r| {
277            r.map_err(FMError::AdapterInvalidName)
278                .map(|OpaquePtr(ptr)| Adapter { ptr })
279        })
280    }
281}
282
283// ============================================================================
284// AdapterCompatibilityFuture — SystemLanguageModel.Adapter.compatibility(for:)
285// ============================================================================
286
287/// Future returned by [`AsyncAdapter::compatibility`].
288///
289/// Resolves to `Result<Vec<String>, FMError>`.
290pub struct AdapterCompatibilityFuture {
291    inner: AsyncCompletionFuture<String>,
292}
293
294impl std::fmt::Debug for AdapterCompatibilityFuture {
295    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
296        f.debug_struct("AdapterCompatibilityFuture")
297            .finish_non_exhaustive()
298    }
299}
300
301impl Future for AdapterCompatibilityFuture {
302    type Output = Result<Vec<String>, FMError>;
303
304    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
305        Pin::new(&mut self.inner).poll(cx).map(|r| {
306            r.map_err(FMError::AdapterCompatibleNotFound)
307                .and_then(|json| {
308                    serde_json::from_str::<Vec<String>>(&json)
309                        .map_err(|e| FMError::DecodingFailure(e.to_string()))
310                })
311        })
312    }
313}
314
315// ============================================================================
316// AsyncSession — async wrapper around LanguageModelSession
317// ============================================================================
318
319/// Async wrapper around [`crate::LanguageModelSession`].
320///
321/// Borrows the session for its lifetime; the session itself must outlive all
322/// in-flight futures.
323///
324/// # Examples
325///
326/// ```rust,no_run
327/// use foundation_models::{LanguageModelSession, SystemLanguageModel};
328/// use foundation_models::async_api::AsyncSession;
329///
330/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
331/// if !SystemLanguageModel::is_available() { return Ok(()); }
332/// pollster::block_on(async {
333///     let session = LanguageModelSession::new();
334///     let reply = AsyncSession::new(&session).respond("Hi!")?.await?;
335///     println!("{}", reply.content);
336///     Ok::<(), Box<dyn std::error::Error>>(())
337/// })
338/// # }
339/// ```
340pub struct AsyncSession<'s> {
341    session: &'s crate::session::LanguageModelSession,
342}
343
344impl<'s> AsyncSession<'s> {
345    /// Wrap a [`crate::LanguageModelSession`] for async use.
346    #[must_use]
347    pub fn new(session: &'s crate::session::LanguageModelSession) -> Self {
348        Self { session }
349    }
350
351    /// Async version of `LanguageModelSession.respond(to:)`.
352    ///
353    /// Corresponds to the Swift `async throws` method
354    /// `LanguageModelSession.respond(to:)`.
355    ///
356    /// # Errors
357    ///
358    /// Returns an [`FMError`] if the model is unavailable or generation fails.
359    pub fn respond(&self, prompt: impl ToPrompt) -> Result<RespondFuture, FMError> {
360        let prompt = prompt.to_prompt()?;
361        let payload = build_text_request_json(&prompt, GenerationOptions::new())?;
362        let session_ptr = self.session.as_ptr();
363        let (future, ctx) = AsyncCompletion::create();
364        unsafe {
365            ffi::fm_session_respond_request_json(
366                session_ptr,
367                payload.as_ptr(),
368                ctx,
369                respond_async_cb,
370            );
371        }
372        Ok(RespondFuture { inner: future })
373    }
374
375    /// Async version of `LanguageModelSession.respond(to:)` with [`GenerationOptions`].
376    ///
377    /// # Errors
378    ///
379    /// Returns an [`FMError`] if the model is unavailable or generation fails.
380    pub fn respond_with_options(
381        &self,
382        prompt: impl ToPrompt,
383        options: GenerationOptions,
384    ) -> Result<RespondFuture, FMError> {
385        let prompt = prompt.to_prompt()?;
386        let payload = build_text_request_json(&prompt, options)?;
387        let session_ptr = self.session.as_ptr();
388        let (future, ctx) = AsyncCompletion::create();
389        unsafe {
390            ffi::fm_session_respond_request_json(
391                session_ptr,
392                payload.as_ptr(),
393                ctx,
394                respond_async_cb,
395            );
396        }
397        Ok(RespondFuture { inner: future })
398    }
399
400    /// Async version of `LanguageModelSession.respond(to:generating:)`.
401    ///
402    /// Generates a structured `GeneratedContent` response according to
403    /// `schema`.  Corresponds to the Swift `async throws` method
404    /// `LanguageModelSession.respond(to:generating:)`.
405    ///
406    /// # Errors
407    ///
408    /// Returns an [`FMError`] if the model is unavailable or generation fails.
409    pub fn respond_generating(
410        &self,
411        prompt: impl ToPrompt,
412        schema: &GenerationSchema,
413        include_schema_in_prompt: bool,
414        options: GenerationOptions,
415    ) -> Result<RespondGeneratingFuture, FMError> {
416        let prompt = prompt.to_prompt()?;
417        let payload =
418            build_structured_request_json(&prompt, options, schema, include_schema_in_prompt)?;
419        let session_ptr = self.session.as_ptr();
420        let (future, ctx) = AsyncCompletion::create();
421        unsafe {
422            ffi::fm_session_respond_request_json(
423                session_ptr,
424                payload.as_ptr(),
425                ctx,
426                respond_async_cb,
427            );
428        }
429        Ok(RespondGeneratingFuture { inner: future })
430    }
431}
432
433impl std::fmt::Debug for AsyncSession<'_> {
434    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
435        f.debug_struct("AsyncSession").finish_non_exhaustive()
436    }
437}
438
439// ============================================================================
440// AsyncAdapter — async adapter lifecycle
441// ============================================================================
442
443/// Namespace for async [`Adapter`] operations.
444///
445/// # Examples
446///
447/// ```rust,no_run
448/// use foundation_models::async_api::AsyncAdapter;
449///
450/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
451/// pollster::block_on(async {
452///     let ids = AsyncAdapter::compatibility("com.example.MyAdapter")?.await?;
453///     println!("compatible: {ids:?}");
454///     Ok::<(), Box<dyn std::error::Error>>(())
455/// })
456/// # }
457/// ```
458pub struct AsyncAdapter;
459
460impl AsyncAdapter {
461    /// Async version of `SystemLanguageModel.Adapter init(name:)`.
462    ///
463    /// Loads the named adapter asynchronously, returning a ready-to-use
464    /// [`Adapter`] handle.
465    ///
466    /// # Errors
467    ///
468    /// Returns an [`FMError::AdapterInvalidName`] if the adapter is not found
469    /// or the name contains a NUL byte.
470    pub fn from_name(name: &str) -> Result<AdapterInitFuture, FMError> {
471        let cname = CString::new(name)
472            .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
473        let (future, ctx) = AsyncCompletion::create();
474        unsafe {
475            ffi::fm_adapter_create_from_name_async(cname.as_ptr(), ctx, adapter_init_async_cb);
476        }
477        Ok(AdapterInitFuture { inner: future })
478    }
479
480    /// Async version of `SystemLanguageModel.Adapter.compatibility(for:)`.
481    ///
482    /// Returns the list of compatible adapter identifiers for the given
483    /// logical adapter name.
484    ///
485    /// # Errors
486    ///
487    /// Returns an [`FMError::AdapterCompatibleNotFound`] on failure.
488    pub fn compatibility(name: &str) -> Result<AdapterCompatibilityFuture, FMError> {
489        let cname = CString::new(name)
490            .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
491        let (future, ctx) = AsyncCompletion::create();
492        unsafe {
493            ffi::fm_adapter_compatibility_async(cname.as_ptr(), ctx, adapter_compat_async_cb);
494        }
495        Ok(AdapterCompatibilityFuture { inner: future })
496    }
497}
498
499// ============================================================================
500// Internal JSON request builders
501// ============================================================================
502
503fn build_text_request_json(
504    prompt: &Prompt,
505    options: GenerationOptions,
506) -> Result<CString, FMError> {
507    build_request_json_inner(prompt, options, None, true)
508}
509
510fn build_structured_request_json(
511    prompt: &Prompt,
512    options: GenerationOptions,
513    schema: &GenerationSchema,
514    include_schema_in_prompt: bool,
515) -> Result<CString, FMError> {
516    build_request_json_inner(prompt, options, Some(schema), include_schema_in_prompt)
517}
518
519fn build_request_json_inner(
520    prompt: &Prompt,
521    options: GenerationOptions,
522    schema: Option<&GenerationSchema>,
523    include_schema_in_prompt: bool,
524) -> Result<CString, FMError> {
525    use crate::generation::SamplingMode;
526    use serde_json::json;
527
528    let sampling = match options.sampling() {
529        SamplingMode::Default => json!({ "mode": "default" }),
530        SamplingMode::Greedy => json!({ "mode": "greedy" }),
531        SamplingMode::TopK(k) => json!({
532            "mode": "top_k",
533            "topK": k,
534            "seed": options.sampling_seed(),
535        }),
536        SamplingMode::TopP(p) => json!({
537            "mode": "top_p",
538            "topP": p,
539            "seed": options.sampling_seed(),
540        }),
541    };
542    let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
543        schema.effective_include_schema_in_prompt(include_schema_in_prompt)
544    });
545    let payload = serde_json::to_string(&json!({
546        "prompt": prompt.to_bridge_value(),
547        "options": {
548            "temperature": options.temperature(),
549            "maximumResponseTokens": options.maximum_response_tokens(),
550            "sampling": sampling,
551        },
552        "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
553        "includeSchemaInPrompt": include_schema_in_prompt,
554    }))
555    .map_err(|e| FMError::InvalidArgument(format!("request not JSON-serializable: {e}")))?;
556    CString::new(payload)
557        .map_err(|e| FMError::InvalidArgument(format!("request JSON contains NUL: {e}")))
558}