Skip to main content

chat_applefm/
client.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3
4use chat_core::types::provider_meta::ProviderMeta;
5
6use crate::api::types::request::{ConvoEntry, TurnPlan, hash_convo};
7use crate::ffi;
8
/// How the on-device model chooses each next token. Covers every mode
/// FoundationModels offers through `GenerationOptions.SamplingMode`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Sampling {
    /// Deterministic decoding.
    Greedy,
    /// Draw from the `k` most probable tokens, optionally with a `seed`.
    TopK { k: u32, seed: Option<u64> },
    /// Nucleus sampling: draw from the smallest set of tokens whose
    /// cumulative probability reaches `p`, optionally with a `seed`.
    TopP { p: f64, seed: Option<u64> },
}
21
22/// Model wiring baked in by the builder. Conversation concerns —
23/// including system prompts — live in `Messages`, not here. Everything
24/// below `lora` is a default; per-call `ChatOptions` override it.
25#[derive(Debug, Default)]
26pub(crate) struct Config {
27    /// Path to a `.fmadapter` package — a LoRA trained with Apple's
28    /// adapter training toolkit, applied over the on-device base model.
29    pub(crate) lora: Option<PathBuf>,
30    pub(crate) temperature: Option<f64>,
31    pub(crate) max_tokens: Option<u32>,
32    pub(crate) sampling: Option<Sampling>,
33}
34
/// Owning wrapper around the id of a live bridge session. Dropping the
/// wrapper releases that session.
#[derive(Debug)]
struct SessionHandle(u64);
38
39impl Drop for SessionHandle {
40    fn drop(&mut self) {
41        ffi::session_free(self.0);
42    }
43}
44
45#[derive(Debug)]
46struct SessionState {
47    handle: SessionHandle,
48    instructions_hash: u64,
49    /// Hash of every conversation entry the session has seen, including
50    /// model replies.
51    prefix_hash: u64,
52    prefix_len: usize,
53}
54
55/// The client's slot for a live bridge session, with the fingerprint
56/// that recognizes append-only continuations. All lifecycle transitions
57/// go through these methods; nothing else touches the state.
58#[derive(Debug, Default)]
59pub(crate) struct Session(Option<SessionState>);
60
61impl Session {
62    /// Decide how to run this turn: incremental prefill against the held
63    /// session, or tear down and rebuild. Reuse requires the same
64    /// instructions and a conversation that extends what the session has
65    /// seen by exactly one message.
66    pub(crate) fn plan(&self, instructions_hash: u64, convo: &[ConvoEntry]) -> TurnPlan {
67        match &self.0 {
68            Some(s)
69                if s.instructions_hash == instructions_hash
70                    && convo.len() == s.prefix_len + 1
71                    && hash_convo(&convo[..s.prefix_len]) == s.prefix_hash =>
72            {
73                TurnPlan::Reuse
74            }
75            _ => TurnPlan::Rebuild,
76        }
77    }
78
79    /// The live bridge session id, if any.
80    pub(crate) fn id(&self) -> Option<u64> {
81        self.0.as_ref().map(|s| s.handle.0)
82    }
83
84    /// Drop the held session (frees the bridge side). Called on rebuild
85    /// and whenever a turn errors — the bridge session may then hold a
86    /// half-applied turn, so the next call starts fresh.
87    pub(crate) fn invalidate(&mut self) {
88        self.0 = None;
89    }
90
91    /// Replace with a freshly created bridge session that has seen
92    /// nothing yet.
93    pub(crate) fn install(&mut self, id: u64, instructions_hash: u64) {
94        self.0 = Some(SessionState {
95            handle: SessionHandle(id),
96            instructions_hash,
97            prefix_hash: hash_convo(&[]),
98            prefix_len: 0,
99        });
100    }
101
102    /// Record that the session has now seen `convo` plus the model's
103    /// reply, so the next append-only turn plans as `Reuse`.
104    pub(crate) fn advance(&mut self, mut convo: Vec<ConvoEntry>, reply_text: String) {
105        if let Some(s) = &mut self.0 {
106            convo.push(ConvoEntry {
107                role: "assistant",
108                text: reply_text,
109            });
110            s.prefix_len = convo.len();
111            s.prefix_hash = hash_convo(&convo);
112        }
113    }
114}
115
116/// Client for the Apple Intelligence on-device foundation model.
117///
118/// There is no model slug and no API key: the OS owns the (one) model.
119/// What varies per client is the configuration — a LoRA adapter and
120/// decoding defaults.
121///
122/// The client holds a live session across turns: append-only
123/// conversations prefill only the newest message, which is what makes
124/// multi-turn chat fast. Clones share that session, so use one client
125/// per concurrent conversation.
126#[derive(Clone, Debug)]
127pub struct AppleFMClient {
128    pub(crate) config: Arc<Config>,
129    /// `Arc` because `ProviderMeta` is not `Clone`.
130    pub(crate) meta: Arc<ProviderMeta>,
131    /// Locked for the whole of each turn — serializes calls so the
132    /// bridge session is never used concurrently.
133    pub(crate) session: Arc<tokio::sync::Mutex<Session>>,
134}
135
136impl AppleFMClient {
137    /// Identifier used as the response `model_slug`: the base model name,
138    /// plus the adapter file stem when a LoRA is loaded.
139    pub fn model_slug(&self) -> String {
140        match self.config.lora.as_deref().and_then(|p| p.file_stem()) {
141            Some(stem) => format!("apple-on-device+{}", stem.to_string_lossy()),
142            None => "apple-on-device".to_owned(),
143        }
144    }
145
146    pub fn provider_meta(&self) -> &ProviderMeta {
147        &self.meta
148    }
149
150    /// Hint the OS to stage model resources for an upcoming turn.
151    ///
152    /// The runtime stages the model down between requests, so a turn that
153    /// follows an idle pause (a user typing, say) pays seconds of warm-up.
154    /// Call this when you can predict a turn is coming — on input focus,
155    /// when the user starts typing — and the warm-up overlaps the wait.
156    ///
157    /// Fire-and-forget and cheap: returns immediately, never fails, and
158    /// is skipped entirely if a turn is already in flight (the model is
159    /// active then anyway). Works before the first turn too.
160    pub fn prewarm(&self) {
161        if let Ok(session) = self.session.try_lock() {
162            let id = session.id().unwrap_or(0);
163            drop(session);
164            ffi::prewarm(id);
165        }
166    }
167
168    /// Fill the metadata the bridge can't: wall-clock duration and
169    /// creation time measured around the call, the active LoRA adapter,
170    /// and whether the turn reused the session (`prefill:
171    /// "incremental"`) or rebuilt it (`prefill: "full"`). Token usage
172    /// stays zero — Apple's FoundationModels API does not expose token
173    /// counts.
174    pub(crate) fn enrich_metadata(
175        &self,
176        metadata: &mut chat_core::types::metadata::Metadata,
177        elapsed: std::time::Duration,
178        reused_session: bool,
179    ) {
180        metadata.duration_ms = Some(elapsed.as_millis() as u64);
181        metadata.created_at = std::time::SystemTime::now()
182            .duration_since(std::time::UNIX_EPOCH)
183            .ok()
184            .map(|d| d.as_secs());
185        metadata.provider_specific.insert(
186            "prefill".to_owned(),
187            serde_json::Value::String(
188                if reused_session {
189                    "incremental"
190                } else {
191                    "full"
192                }
193                .to_owned(),
194            ),
195        );
196        if let Some(lora) = &self.config.lora {
197            metadata.provider_specific.insert(
198                "lora".to_owned(),
199                serde_json::Value::String(lora.to_string_lossy().into_owned()),
200            );
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::types::request::hash_instructions;

    /// Build a conversation from `(role, text)` pairs.
    fn convo(turns: &[(&'static str, &str)]) -> Vec<ConvoEntry> {
        turns
            .iter()
            .map(|&(role, text)| ConvoEntry {
                role,
                text: text.to_owned(),
            })
            .collect()
    }

    /// Walks one session through its whole lifecycle and checks the plan
    /// chosen at each step.
    #[test]
    fn session_lifecycle_plans_correctly() {
        let sys = hash_instructions(Some("sys"));
        let mut session = Session::default();

        // Nothing is held yet, so the first turn must rebuild.
        let opening = convo(&[("user", "hi")]);
        assert_eq!(session.plan(sys, &opening), TurnPlan::Rebuild);

        // Play out the first turn: install a session, then record the
        // conversation together with the model's reply.
        session.install(1, sys);
        session.advance(opening, "yo".to_owned());

        // One new message on top of the seen history: reuse.
        let follow_up = convo(&[
            ("user", "hi"),
            ("assistant", "yo"),
            ("user", "how are you?"),
        ]);
        assert_eq!(session.plan(sys, &follow_up), TurnPlan::Reuse);

        // Same shape, but the history was edited: rebuild.
        let rewritten = convo(&[
            ("user", "hi EDITED"),
            ("assistant", "yo"),
            ("user", "how are you?"),
        ]);
        assert_eq!(session.plan(sys, &rewritten), TurnPlan::Rebuild);

        // Same conversation under different instructions: rebuild.
        let other = hash_instructions(Some("other"));
        assert_eq!(session.plan(other, &follow_up), TurnPlan::Rebuild);

        // Once invalidated, no session is held and every plan rebuilds.
        session.invalidate();
        assert_eq!(session.id(), None);
        assert_eq!(session.plan(sys, &follow_up), TurnPlan::Rebuild);
    }
}
257}