Skip to main content

chat_applefm/api/
completion.rs

1use async_trait::async_trait;
2use chat_core::error::{ChatError, ChatFailure};
3use chat_core::traits::CompletionProvider;
4use chat_core::types::messages::Messages;
5use chat_core::types::options::ChatOptions;
6use chat_core::types::provider_meta::ProviderMeta;
7use chat_core::types::response::ChatResponse;
8use chat_core::types::tools::ToolDeclarations;
9
10use crate::api::types::request::{self, TurnPlan};
11use crate::api::types::response;
12use crate::client::AppleFMClient;
13use crate::ffi;
14
15#[async_trait]
16impl CompletionProvider for AppleFMClient {
17    async fn complete(
18        &mut self,
19        messages: &mut Messages,
20        tool_declarations: Option<&dyn ToolDeclarations>,
21        options: Option<&ChatOptions>,
22        structured_output: Option<&schemars::Schema>,
23    ) -> Result<ChatResponse, ChatFailure> {
24        let (instructions, convo) =
25            request::prepare(messages, structured_output, tool_declarations.is_some())?;
26        let wire_options = request::merge_options(&self.config, options);
27        let instructions_hash = request::hash_instructions(instructions.as_deref());
28
29        // Held for the whole turn: serializes use of the bridge session.
30        let mut session = self.session.lock().await;
31
32        let reused = match session.plan(instructions_hash, &convo) {
33            TurnPlan::Reuse => true,
34            TurnPlan::Rebuild => {
35                session.invalidate();
36                let config_json =
37                    request::session_config_json(instructions.as_deref(), &self.config)?;
38                let created_json =
39                    tokio::task::spawn_blocking(move || ffi::session_create(&config_json))
40                        .await
41                        .map_err(join_error)?;
42                session.install(
43                    response::parse_session_created(&created_json)?,
44                    instructions_hash,
45                );
46                false
47            }
48        };
49
50        let message = if reused {
51            convo
52                .last()
53                .expect("prepare guarantees non-empty")
54                .text
55                .clone()
56        } else {
57            request::render_full(&convo)
58        };
59        let turn_json = request::turn_request_json(message, wire_options)?;
60        let session_id = session.id().expect("session installed above");
61
62        // The bridge call blocks (model inference); keep it off the
63        // async workers.
64        let started = std::time::Instant::now();
65        let reply_json =
66            tokio::task::spawn_blocking(move || ffi::session_respond(session_id, &turn_json))
67                .await
68                .map_err(join_error)?;
69
70        match response::into_core(&self.model_slug(), &reply_json) {
71            Ok(mut chat_response) => {
72                // Advance the fingerprint past this turn (the chat loop
73                // appends the reply to `Messages`, so the next call sees
74                // convo + reply + new user message).
75                let reply_text = chat_response
76                    .content
77                    .parts
78                    .text_response()
79                    .map(|t| t.as_str().to_owned())
80                    .unwrap_or_default();
81                session.advance(convo, reply_text);
82
83                if let Some(metadata) = chat_response.metadata.as_mut() {
84                    self.enrich_metadata(metadata, started.elapsed(), reused);
85                }
86                Ok(chat_response)
87            }
88            Err(failure) => {
89                // The bridge session may hold a half-applied turn.
90                session.invalidate();
91                Err(failure)
92            }
93        }
94    }
95
96    fn metadata(&self) -> Option<&ProviderMeta> {
97        Some(&self.meta)
98    }
99}
100
101fn join_error(e: tokio::task::JoinError) -> ChatFailure {
102    ChatFailure::from_err(ChatError::Other(format!("bridge task failed: {e}")))
103}