Skip to main content

chat_applefm/
builder.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3
4use chat_core::error::{ChatError, ChatFailure};
5use chat_core::types::provider_meta::ProviderMeta;
6
7use crate::client::{AppleFMClient, Config, Sampling};
8
9/// Builder for [`AppleFMClient`].
10///
11/// A pure connection stub: it wires up *which model variant* to talk to
12/// (base, or base + LoRA) and nothing else. Unlike other providers there
13/// is no `with_model` typestate — the model is the one the OS ships.
14/// Conversation concerns (system prompts, options) belong to the chat:
15/// push a `System`-role message into `Messages` and the provider maps it
16/// onto the session's instructions.
17///
18/// ```no_run
19/// # fn run() -> Result<(), Box<dyn std::error::Error>> {
20/// use chat_applefm::AppleFMBuilder;
21///
22/// let client = AppleFMBuilder::new()
23///     .with_lora("adapters/transcripts.fmadapter")
24///     .build()?;
25/// # Ok(()) }
26/// ```
27#[derive(Debug, Default)]
28pub struct AppleFMBuilder {
29    lora: Option<PathBuf>,
30    temperature: Option<f64>,
31    max_tokens: Option<u32>,
32    sampling: Option<Sampling>,
33    description: Option<String>,
34}
35
36impl AppleFMBuilder {
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Apply a LoRA fine-tune over the on-device base model: the path to a
42    /// `.fmadapter` package produced by Apple's adapter training toolkit.
43    ///
44    /// Adapters are tied to a specific base-model version — when a macOS
45    /// update rolls the base model, the adapter must be retrained. Loading
46    /// an incompatible adapter fails at request time with a provider error.
47    pub fn with_lora(mut self, path: impl Into<PathBuf>) -> Self {
48        self.lora = Some(path.into());
49        self
50    }
51
52    /// Default decoding temperature. Overridden per call by
53    /// `ChatOptions::temperature`.
54    pub fn with_temperature(mut self, temperature: f64) -> Self {
55        self.temperature = Some(temperature);
56        self
57    }
58
59    /// Default response-length cap. Overridden per call by
60    /// `ChatOptions::max_tokens`.
61    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
62        self.max_tokens = Some(max_tokens);
63        self
64    }
65
66    /// Default sampling mode — greedy, top-k, or top-p (the complete set
67    /// FoundationModels exposes). Overridden per call when `ChatOptions`
68    /// carries any sampling key (`top_p`, or `greedy` / `top_k` / `seed`
69    /// in its metadata).
70    pub fn with_sampling(mut self, sampling: Sampling) -> Self {
71        self.sampling = Some(sampling);
72        self
73    }
74
75    pub fn with_description(mut self, description: impl Into<String>) -> Self {
76        self.description = Some(description.into());
77        self
78    }
79
80    /// Build the client, validating the configuration upfront — a missing
81    /// `.fmadapter` path or nonsensical sampling parameters fail here, not
82    /// on the first request. Cheap otherwise: the model is probed and the
83    /// session created at request time. Call [`crate::availability`]
84    /// first to know whether requests can succeed on this machine at all.
85    pub fn build(self) -> Result<AppleFMClient, ChatFailure> {
86        if let Some(lora) = &self.lora
87            && !lora.exists()
88        {
89            return Err(invalid(format!(
90                "LoRA adapter not found at {}",
91                lora.display()
92            )));
93        }
94        if let Some(t) = self.temperature
95            && (!t.is_finite() || t < 0.0)
96        {
97            return Err(invalid(format!("temperature must be >= 0, got {t}")));
98        }
99        if self.max_tokens == Some(0) {
100            return Err(invalid("max_tokens must be >= 1".into()));
101        }
102        match self.sampling {
103            Some(Sampling::TopK { k: 0, .. }) => {
104                return Err(invalid("top-k sampling needs k >= 1".into()));
105            }
106            Some(Sampling::TopP { p, .. }) if !p.is_finite() || p <= 0.0 || p > 1.0 => {
107                return Err(invalid(format!(
108                    "top-p sampling needs p in (0, 1], got {p}"
109                )));
110            }
111            _ => {}
112        }
113
114        Ok(AppleFMClient {
115            config: Arc::new(Config {
116                lora: self.lora,
117                temperature: self.temperature,
118                max_tokens: self.max_tokens,
119                sampling: self.sampling,
120            }),
121            meta: Arc::new(ProviderMeta {
122                description: self.description,
123                ..Default::default()
124            }),
125            session: Arc::new(tokio::sync::Mutex::new(crate::client::Session::default())),
126        })
127    }
128}
129
130fn invalid(message: String) -> ChatFailure {
131    ChatFailure::from_err(ChatError::Provider(format!(
132        "chat-applefm builder: {message}"
133    )))
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_missing_lora_path() {
        let err = AppleFMBuilder::new()
            .with_lora("/definitely/not/here.fmadapter")
            .build()
            .unwrap_err();
        assert!(err.err.to_string().contains("not found"));
    }

    #[test]
    fn accepts_existing_lora_path() {
        // Any path that exists works for the check; the manifest is one.
        let manifest = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        assert!(AppleFMBuilder::new().with_lora(manifest).build().is_ok());
    }

    #[test]
    fn rejects_bad_sampling_and_options() {
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopP { p: 1.5, seed: None })
                .build()
                .is_err()
        );
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopK { k: 0, seed: None })
                .build()
                .is_err()
        );
        assert!(
            AppleFMBuilder::new()
                .with_temperature(-0.1)
                .build()
                .is_err()
        );
        assert!(AppleFMBuilder::new().with_max_tokens(0).build().is_err());
        assert!(AppleFMBuilder::new().build().is_ok());
    }

    /// NaN and ±infinity must be rejected, not just negative values.
    #[test]
    fn rejects_non_finite_temperature() {
        for t in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                AppleFMBuilder::new().with_temperature(t).build().is_err(),
                "temperature {t} should be rejected"
            );
        }
    }

    /// Top-p is an open-closed interval `(0, 1]`: zero, negatives, and
    /// anything above one are invalid.
    #[test]
    fn rejects_top_p_outside_unit_interval() {
        for p in [0.0, -0.5, 1.0001] {
            assert!(
                AppleFMBuilder::new()
                    .with_sampling(Sampling::TopP { p, seed: None })
                    .build()
                    .is_err(),
                "top-p {p} should be rejected"
            );
        }
    }

    /// The inclusive edges of every validated range must be accepted.
    #[test]
    fn accepts_boundary_values() {
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopP { p: 1.0, seed: None })
                .build()
                .is_ok()
        );
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopK { k: 1, seed: None })
                .build()
                .is_ok()
        );
        assert!(AppleFMBuilder::new().with_temperature(0.0).build().is_ok());
        assert!(AppleFMBuilder::new().with_max_tokens(1).build().is_ok());
    }
}