1use std::ffi::{c_void, CStr, CString};
46use std::future::Future;
47use std::pin::Pin;
48use std::task::{Context, Poll};
49
50use doom_fish_utils::completion::{error_from_cstr, AsyncCompletion, AsyncCompletionFuture};
51use serde::Deserialize;
52
53use crate::content::{BridgeGeneratedContent, GeneratedContent};
54use crate::error::FMError;
55use crate::ffi;
56use crate::generation::GenerationOptions;
57use crate::model::Adapter;
58use crate::prompt::{Prompt, ToPrompt};
59use crate::schema::GenerationSchema;
60use crate::session::{decode_bridge_text_response, SessionResponse};
61use crate::transcript::Transcript;
62
63#[derive(Debug, Deserialize)]
68struct AsyncBridgeStructuredResponse {
69 content: BridgeGeneratedContent,
70 #[serde(rename = "rawContent")]
71 raw_content: BridgeGeneratedContent,
72 #[serde(rename = "transcriptJSON")]
73 transcript_json: String,
74}
75
/// Raw-pointer wrapper used to carry the adapter handle produced by the bridge through
/// `AsyncCompletion` (see `adapter_init_async_cb`) to `AdapterInitFuture`, which unwraps
/// it into an `Adapter`.
struct OpaquePtr(*mut c_void);
// SAFETY: NOTE(review): `Send` is asserted so the pointer can travel from the bridge
// callback to whichever thread polls `AdapterInitFuture`. This assumes the underlying
// adapter handle may be used from a thread other than the one that produced it — confirm
// against the bridge implementation.
unsafe impl Send for OpaquePtr {}
91
92unsafe extern "C" fn respond_async_cb(
110 ctx: *mut c_void,
111 response: *mut std::ffi::c_char,
112 error: *mut std::ffi::c_char,
113 status: i32,
114) {
115 if status == ffi::status::OK && !response.is_null() {
116 let s = unsafe { CStr::from_ptr(response) }
117 .to_string_lossy()
118 .into_owned();
119 unsafe { ffi::fm_string_free(response) };
120 unsafe { AsyncCompletion::complete_ok(ctx, s) };
121 } else {
122 let fm_err = crate::error::from_swift(status, error);
125 unsafe { AsyncCompletion::<String>::complete_err(ctx, fm_err.to_string()) };
126 }
127}
128
129unsafe extern "C" fn adapter_init_async_cb(
139 result: *mut c_void,
140 error: *const std::ffi::c_char,
141 ctx: *mut c_void,
142) {
143 if !error.is_null() {
144 let msg = unsafe { error_from_cstr(error) };
145 unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, msg) };
146 } else if !result.is_null() {
147 unsafe { AsyncCompletion::complete_ok(ctx, OpaquePtr(result)) };
148 } else {
149 unsafe { AsyncCompletion::<OpaquePtr>::complete_err(ctx, "null adapter pointer".into()) };
150 }
151}
152
153unsafe extern "C" fn adapter_compat_async_cb(
166 result: *mut c_void,
167 error: *const std::ffi::c_char,
168 ctx: *mut c_void,
169) {
170 if !error.is_null() {
171 let msg = unsafe { error_from_cstr(error) };
172 unsafe { AsyncCompletion::<String>::complete_err(ctx, msg) };
173 } else if !result.is_null() {
174 let s = unsafe { CStr::from_ptr(result.cast::<std::ffi::c_char>()) }
175 .to_string_lossy()
176 .into_owned();
177 unsafe { ffi::fm_string_free(result.cast::<std::ffi::c_char>()) };
179 unsafe { AsyncCompletion::complete_ok(ctx, s) };
180 } else {
181 unsafe { AsyncCompletion::<String>::complete_err(ctx, "null compatibility result".into()) };
182 }
183}
184
185unsafe extern "C" fn adapter_compile_async_cb(
189 context: *mut c_void,
190 response: *mut std::ffi::c_char,
191 error: *mut std::ffi::c_char,
192 status: i32,
193) {
194 if !response.is_null() {
195 unsafe { ffi::fm_string_free(response) };
196 }
197
198 if status == ffi::status::OK {
199 if !error.is_null() {
200 unsafe { ffi::fm_string_free(error) };
201 }
202 unsafe { AsyncCompletion::complete_ok(context, ()) };
203 return;
204 }
205
206 let message = if error.is_null() {
207 crate::error::from_swift(status, std::ptr::null_mut()).to_string()
208 } else {
209 let message = unsafe { CStr::from_ptr(error) }.to_string_lossy().into_owned();
210 unsafe { ffi::fm_string_free(error) };
211 message
212 };
213 unsafe { AsyncCompletion::<()>::complete_err(context, message) };
214}
215
216pub struct RespondFuture {
224 inner: AsyncCompletionFuture<String>,
225}
226
227impl std::fmt::Debug for RespondFuture {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 f.debug_struct("RespondFuture").finish_non_exhaustive()
230 }
231}
232
233impl Future for RespondFuture {
234 type Output = Result<SessionResponse<String>, FMError>;
235
236 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 Pin::new(&mut self.inner).poll(cx).map(|r| {
238 r.map_err(|msg| FMError::Unknown {
239 code: ffi::status::UNKNOWN,
240 message: msg,
241 })
242 .and_then(|json| decode_bridge_text_response(&json))
243 })
244 }
245}
246
247pub struct RespondGeneratingFuture {
255 inner: AsyncCompletionFuture<String>,
256}
257
258impl std::fmt::Debug for RespondGeneratingFuture {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_struct("RespondGeneratingFuture")
261 .finish_non_exhaustive()
262 }
263}
264
265impl Future for RespondGeneratingFuture {
266 type Output = Result<SessionResponse<GeneratedContent>, FMError>;
267
268 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
269 Pin::new(&mut self.inner).poll(cx).map(|r| {
270 r.map_err(|msg| FMError::Unknown {
271 code: ffi::status::UNKNOWN,
272 message: msg,
273 })
274 .and_then(|json| {
275 let response: AsyncBridgeStructuredResponse = serde_json::from_str(&json)
276 .map_err(|e| FMError::DecodingFailure(e.to_string()))?;
277 Ok(SessionResponse {
278 content: GeneratedContent::from_bridge_payload(response.content, true)?,
279 raw_content: GeneratedContent::from_bridge_payload(response.raw_content, true)?,
280 transcript: Transcript::from_json_str(&response.transcript_json)?,
281 })
282 })
283 })
284 }
285}
286
287pub struct AdapterInitFuture {
295 inner: AsyncCompletionFuture<OpaquePtr>,
296}
297
298impl std::fmt::Debug for AdapterInitFuture {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 f.debug_struct("AdapterInitFuture").finish_non_exhaustive()
301 }
302}
303
304impl Future for AdapterInitFuture {
305 type Output = Result<Adapter, FMError>;
306
307 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
308 Pin::new(&mut self.inner).poll(cx).map(|r| {
309 r.map_err(FMError::AdapterInvalidName)
310 .map(|OpaquePtr(ptr)| Adapter { ptr })
311 })
312 }
313}
314
315pub struct AdapterCompatibilityFuture {
323 inner: AsyncCompletionFuture<String>,
324}
325
326impl std::fmt::Debug for AdapterCompatibilityFuture {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 f.debug_struct("AdapterCompatibilityFuture")
329 .finish_non_exhaustive()
330 }
331}
332
333impl Future for AdapterCompatibilityFuture {
334 type Output = Result<Vec<String>, FMError>;
335
336 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
337 Pin::new(&mut self.inner).poll(cx).map(|r| {
338 r.map_err(FMError::AdapterCompatibleNotFound)
339 .and_then(|json| {
340 serde_json::from_str::<Vec<String>>(&json)
341 .map_err(|e| FMError::DecodingFailure(e.to_string()))
342 })
343 })
344 }
345}
346
347pub struct CompileAdapterFuture {
351 inner: AsyncCompletionFuture<()>,
352}
353
354impl std::fmt::Debug for CompileAdapterFuture {
355 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
356 f.debug_struct("CompileAdapterFuture")
357 .finish_non_exhaustive()
358 }
359}
360
361impl Future for CompileAdapterFuture {
362 type Output = Result<(), FMError>;
363
364 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
365 Pin::new(&mut self.inner).poll(cx).map(|r| {
366 r.map_err(|message| FMError::Unknown {
367 code: ffi::status::UNKNOWN,
368 message,
369 })
370 })
371 }
372}
373
/// Asynchronous front end over a borrowed `LanguageModelSession`.
///
/// Each request method encodes its inputs as a JSON request, passes it to
/// `fm_session_respond_request_json`, and returns a future that is completed from
/// `respond_async_cb`.
pub struct AsyncSession<'s> {
    // Borrowed session whose raw pointer (`as_ptr`) is handed to the bridge.
    session: &'s crate::session::LanguageModelSession,
}
402
403impl<'s> AsyncSession<'s> {
404 #[must_use]
406 pub fn new(session: &'s crate::session::LanguageModelSession) -> Self {
407 Self { session }
408 }
409
410 pub fn respond(&self, prompt: impl ToPrompt) -> Result<RespondFuture, FMError> {
419 let prompt = prompt.to_prompt()?;
420 let payload = build_text_request_json(&prompt, GenerationOptions::new())?;
421 let session_ptr = self.session.as_ptr();
422 let (future, ctx) = AsyncCompletion::create();
423 unsafe {
424 ffi::fm_session_respond_request_json(
425 session_ptr,
426 payload.as_ptr(),
427 ctx,
428 respond_async_cb,
429 );
430 }
431 Ok(RespondFuture { inner: future })
432 }
433
434 pub fn respond_with_options(
440 &self,
441 prompt: impl ToPrompt,
442 options: GenerationOptions,
443 ) -> Result<RespondFuture, FMError> {
444 let prompt = prompt.to_prompt()?;
445 let payload = build_text_request_json(&prompt, options)?;
446 let session_ptr = self.session.as_ptr();
447 let (future, ctx) = AsyncCompletion::create();
448 unsafe {
449 ffi::fm_session_respond_request_json(
450 session_ptr,
451 payload.as_ptr(),
452 ctx,
453 respond_async_cb,
454 );
455 }
456 Ok(RespondFuture { inner: future })
457 }
458
459 pub fn respond_generating(
469 &self,
470 prompt: impl ToPrompt,
471 schema: &GenerationSchema,
472 include_schema_in_prompt: bool,
473 options: GenerationOptions,
474 ) -> Result<RespondGeneratingFuture, FMError> {
475 let prompt = prompt.to_prompt()?;
476 let payload =
477 build_structured_request_json(&prompt, options, schema, include_schema_in_prompt)?;
478 let session_ptr = self.session.as_ptr();
479 let (future, ctx) = AsyncCompletion::create();
480 unsafe {
481 ffi::fm_session_respond_request_json(
482 session_ptr,
483 payload.as_ptr(),
484 ctx,
485 respond_async_cb,
486 );
487 }
488 Ok(RespondGeneratingFuture { inner: future })
489 }
490}
491
492impl std::fmt::Debug for AsyncSession<'_> {
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 f.debug_struct("AsyncSession").finish_non_exhaustive()
495 }
496}
497
/// Namespace for asynchronous adapter operations: loading an adapter by name, querying
/// compatibility, and compiling a loaded adapter.
///
/// Derives `Debug` like every other public type in this module, so it can appear in
/// debug output and satisfy `missing_debug_implementations`.
#[derive(Debug)]
pub struct AsyncAdapter;
518
519impl AsyncAdapter {
520 pub fn from_name(name: &str) -> Result<AdapterInitFuture, FMError> {
530 let cname = CString::new(name)
531 .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
532 let (future, ctx) = AsyncCompletion::create();
533 unsafe {
534 ffi::fm_adapter_create_from_name_async(cname.as_ptr(), ctx, adapter_init_async_cb);
535 }
536 Ok(AdapterInitFuture { inner: future })
537 }
538
539 pub fn compatibility(name: &str) -> Result<AdapterCompatibilityFuture, FMError> {
548 let cname = CString::new(name)
549 .map_err(|e| FMError::InvalidArgument(format!("NUL byte in adapter name: {e}")))?;
550 let (future, ctx) = AsyncCompletion::create();
551 unsafe {
552 ffi::fm_adapter_compatibility_async(cname.as_ptr(), ctx, adapter_compat_async_cb);
553 }
554 Ok(AdapterCompatibilityFuture { inner: future })
555 }
556
557 #[must_use]
559 pub fn compile(adapter: &Adapter) -> CompileAdapterFuture {
560 let (future, ctx) = AsyncCompletion::create();
561 unsafe {
562 ffi::fm_adapter_compile(adapter.ptr, ctx, adapter_compile_async_cb);
563 }
564 CompileAdapterFuture { inner: future }
565 }
566}
567
568fn build_text_request_json(
573 prompt: &Prompt,
574 options: GenerationOptions,
575) -> Result<CString, FMError> {
576 build_request_json_inner(prompt, options, None, true)
577}
578
579fn build_structured_request_json(
580 prompt: &Prompt,
581 options: GenerationOptions,
582 schema: &GenerationSchema,
583 include_schema_in_prompt: bool,
584) -> Result<CString, FMError> {
585 build_request_json_inner(prompt, options, Some(schema), include_schema_in_prompt)
586}
587
588fn build_request_json_inner(
589 prompt: &Prompt,
590 options: GenerationOptions,
591 schema: Option<&GenerationSchema>,
592 include_schema_in_prompt: bool,
593) -> Result<CString, FMError> {
594 use crate::generation::SamplingMode;
595 use serde_json::json;
596
597 let sampling = match options.sampling() {
598 SamplingMode::Default => json!({ "mode": "default" }),
599 SamplingMode::Greedy => json!({ "mode": "greedy" }),
600 SamplingMode::TopK(k) => json!({
601 "mode": "top_k",
602 "topK": k,
603 "seed": options.sampling_seed(),
604 }),
605 SamplingMode::TopP(p) => json!({
606 "mode": "top_p",
607 "topP": p,
608 "seed": options.sampling_seed(),
609 }),
610 };
611 let include_schema_in_prompt = schema.map_or(include_schema_in_prompt, |schema| {
612 schema.effective_include_schema_in_prompt(include_schema_in_prompt)
613 });
614 let payload = serde_json::to_string(&json!({
615 "prompt": prompt.to_bridge_value(),
616 "options": {
617 "temperature": options.temperature(),
618 "maximumResponseTokens": options.maximum_response_tokens(),
619 "sampling": sampling,
620 },
621 "schemaJSON": schema.map(GenerationSchema::bridge_request_json),
622 "includeSchemaInPrompt": include_schema_in_prompt,
623 }))
624 .map_err(|e| FMError::InvalidArgument(format!("request not JSON-serializable: {e}")))?;
625 CString::new(payload)
626 .map_err(|e| FMError::InvalidArgument(format!("request JSON contains NUL: {e}")))
627}