1use std::path::PathBuf;
2use std::sync::Arc;
3
4use chat_core::error::{ChatError, ChatFailure};
5use chat_core::types::provider_meta::ProviderMeta;
6
7use crate::client::{AppleFMClient, Config, Sampling};
8
9#[derive(Debug, Default)]
28pub struct AppleFMBuilder {
29 lora: Option<PathBuf>,
30 temperature: Option<f64>,
31 max_tokens: Option<u32>,
32 sampling: Option<Sampling>,
33 description: Option<String>,
34}
35
36impl AppleFMBuilder {
37 pub fn new() -> Self {
38 Self::default()
39 }
40
41 pub fn with_lora(mut self, path: impl Into<PathBuf>) -> Self {
48 self.lora = Some(path.into());
49 self
50 }
51
52 pub fn with_temperature(mut self, temperature: f64) -> Self {
55 self.temperature = Some(temperature);
56 self
57 }
58
59 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
62 self.max_tokens = Some(max_tokens);
63 self
64 }
65
66 pub fn with_sampling(mut self, sampling: Sampling) -> Self {
71 self.sampling = Some(sampling);
72 self
73 }
74
75 pub fn with_description(mut self, description: impl Into<String>) -> Self {
76 self.description = Some(description.into());
77 self
78 }
79
80 pub fn build(self) -> Result<AppleFMClient, ChatFailure> {
86 if let Some(lora) = &self.lora
87 && !lora.exists()
88 {
89 return Err(invalid(format!(
90 "LoRA adapter not found at {}",
91 lora.display()
92 )));
93 }
94 if let Some(t) = self.temperature
95 && (!t.is_finite() || t < 0.0)
96 {
97 return Err(invalid(format!("temperature must be >= 0, got {t}")));
98 }
99 if self.max_tokens == Some(0) {
100 return Err(invalid("max_tokens must be >= 1".into()));
101 }
102 match self.sampling {
103 Some(Sampling::TopK { k: 0, .. }) => {
104 return Err(invalid("top-k sampling needs k >= 1".into()));
105 }
106 Some(Sampling::TopP { p, .. }) if !p.is_finite() || p <= 0.0 || p > 1.0 => {
107 return Err(invalid(format!(
108 "top-p sampling needs p in (0, 1], got {p}"
109 )));
110 }
111 _ => {}
112 }
113
114 Ok(AppleFMClient {
115 config: Arc::new(Config {
116 lora: self.lora,
117 temperature: self.temperature,
118 max_tokens: self.max_tokens,
119 sampling: self.sampling,
120 }),
121 meta: Arc::new(ProviderMeta {
122 description: self.description,
123 ..Default::default()
124 }),
125 session: Arc::new(tokio::sync::Mutex::new(crate::client::Session::default())),
126 })
127 }
128}
129
130fn invalid(message: String) -> ChatFailure {
131 ChatFailure::from_err(ChatError::Provider(format!(
132 "chat-applefm builder: {message}"
133 )))
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_missing_lora_path() {
        let err = AppleFMBuilder::new()
            .with_lora("/definitely/not/here.fmadapter")
            .build()
            .unwrap_err();
        assert!(err.err.to_string().contains("not found"));
    }

    #[test]
    fn accepts_existing_lora_path() {
        let manifest = concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml");
        assert!(AppleFMBuilder::new().with_lora(manifest).build().is_ok());
    }

    #[test]
    fn rejects_bad_sampling_and_options() {
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopP { p: 1.5, seed: None })
                .build()
                .is_err()
        );
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopK { k: 0, seed: None })
                .build()
                .is_err()
        );
        assert!(
            AppleFMBuilder::new()
                .with_temperature(-0.1)
                .build()
                .is_err()
        );
        assert!(AppleFMBuilder::new().with_max_tokens(0).build().is_err());
        assert!(AppleFMBuilder::new().build().is_ok());
    }

    /// NaN and both infinities must be rejected, not just negative values.
    #[test]
    fn rejects_non_finite_temperature() {
        for t in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            assert!(
                AppleFMBuilder::new().with_temperature(t).build().is_err(),
                "temperature {t} was accepted"
            );
        }
    }

    /// The lower bound of the top-p interval `(0, 1]` is exclusive.
    #[test]
    fn rejects_non_positive_top_p() {
        for p in [0.0, -0.5] {
            assert!(
                AppleFMBuilder::new()
                    .with_sampling(Sampling::TopP { p, seed: None })
                    .build()
                    .is_err(),
                "top-p {p} was accepted"
            );
        }
    }

    /// The smallest valid value of each option must be accepted.
    #[test]
    fn accepts_boundary_values() {
        assert!(AppleFMBuilder::new().with_temperature(0.0).build().is_ok());
        assert!(AppleFMBuilder::new().with_max_tokens(1).build().is_ok());
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopK { k: 1, seed: None })
                .build()
                .is_ok()
        );
        assert!(
            AppleFMBuilder::new()
                .with_sampling(Sampling::TopP { p: 1.0, seed: None })
                .build()
                .is_ok()
        );
    }

    /// Options set on the builder must reach the built client unchanged.
    #[test]
    fn build_carries_options_into_client() {
        let Ok(client) = AppleFMBuilder::new()
            .with_max_tokens(64)
            .with_description("demo")
            .build()
        else {
            panic!("valid options were rejected");
        };
        assert_eq!(client.config.max_tokens, Some(64));
        assert_eq!(client.meta.description.as_deref(), Some("demo"));
    }
}