1use std::marker::PhantomData;
27
28use chat_completions::{
29 ChatError, CompletionsBuilder, CompletionsClient, Request, ReqwestTransport, Transport,
30 TransportError,
31};
32use serde::Deserialize;
33use serde_json::json;
34
/// Name of the environment variable read by [`OllamaBuilder::new`] to locate the daemon.
const OLLAMA_HOST_ENV: &str = "OLLAMA_HOST";

/// Daemon URL used by [`OllamaBuilder::new`] when `OLLAMA_HOST` is not available.
pub const DEFAULT_OLLAMA_HOST: &str = "http://localhost:11434";
39
/// Typestate marker for an [`OllamaBuilder`] that has no model yet; `build` and `pull`
/// are not available in this state.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithoutModel;

/// Typestate marker for an [`OllamaBuilder`] whose model was set with `with_model`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct WithModel;
42
43pub struct OllamaBuilder<M = WithoutModel, T: Transport = ReqwestTransport> {
46 scheme: String,
47 host: String,
48 model: Option<String>,
49 api_key: Option<String>,
50 extra_headers: Vec<(String, String)>,
51 description: Option<String>,
52 transport: Option<T>,
53 _m: PhantomData<M>,
54}
55
56impl Default for OllamaBuilder<WithoutModel, ReqwestTransport> {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl OllamaBuilder<WithoutModel, ReqwestTransport> {
63 pub fn new() -> Self {
65 let host =
66 std::env::var(OLLAMA_HOST_ENV).unwrap_or_else(|_| DEFAULT_OLLAMA_HOST.to_string());
67 Self::with_host(host)
68 }
69
70 pub fn with_host(host: impl AsRef<str>) -> Self {
74 let parsed = url::Url::parse(host.as_ref()).expect("Invalid Ollama host URL");
75 let scheme = parsed.scheme().to_string();
76 let host_port = parsed
77 .host_str()
78 .expect("Ollama host URL missing host")
79 .to_string()
80 + &parsed.port().map(|p| format!(":{p}")).unwrap_or_default();
81
82 Self {
83 scheme,
84 host: host_port,
85 model: None,
86 api_key: None,
87 extra_headers: Vec::new(),
88 description: None,
89 transport: Some(ReqwestTransport::default()),
90 _m: PhantomData,
91 }
92 }
93}
94
95impl<M, T: Transport> OllamaBuilder<M, T> {
96 pub async fn ping(&self) -> Result<(), ChatError> {
103 let transport = self.transport.as_ref().expect("transport set");
104 let req = Request {
105 scheme: self.scheme.clone(),
106 host: self.host.clone(),
107 path: "/api/version".to_string(),
108 headers: vec![("Content-Type".into(), "application/json".into())],
109 body: Vec::new(),
110 };
111 match transport.send(req).await {
112 Ok(_) => Ok(()),
113 Err(e) => Err(map_transport_error(&self.scheme, &self.host, e)),
114 }
115 }
116
117 pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
118 self.api_key = Some(api_key.into());
119 self
120 }
121
122 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
123 self.extra_headers.push((key.into(), value.into()));
124 self
125 }
126
127 pub fn with_description(mut self, description: impl Into<String>) -> Self {
128 self.description = Some(description.into());
129 self
130 }
131
132 pub fn with_transport<T2: Transport>(self, transport: T2) -> OllamaBuilder<M, T2> {
133 OllamaBuilder {
134 scheme: self.scheme,
135 host: self.host,
136 model: self.model,
137 api_key: self.api_key,
138 extra_headers: self.extra_headers,
139 description: self.description,
140 transport: Some(transport),
141 _m: PhantomData,
142 }
143 }
144}
145
146impl<T: Transport> OllamaBuilder<WithoutModel, T> {
147 pub fn with_model(self, model: impl Into<String>) -> OllamaBuilder<WithModel, T> {
148 OllamaBuilder {
149 scheme: self.scheme,
150 host: self.host,
151 model: Some(model.into()),
152 api_key: self.api_key,
153 extra_headers: self.extra_headers,
154 description: self.description,
155 transport: self.transport,
156 _m: PhantomData,
157 }
158 }
159}
160
161impl<T: Transport> OllamaBuilder<WithModel, T> {
162 pub fn build(self) -> CompletionsClient<T> {
164 let transport = self.transport.expect("transport set");
165 let model = self.model.expect("model set");
166
167 let base_url = format!("{}://{}/v1", self.scheme, self.host);
168 let mut b = CompletionsBuilder::new()
169 .with_base_url(base_url)
170 .with_model(model)
171 .with_transport(transport);
172
173 if let Some(key) = self.api_key {
174 b = b.with_api_key(key);
175 }
176 for (k, v) in self.extra_headers {
177 b = b.with_header(k, v);
178 }
179 if let Some(desc) = self.description {
180 b = b.with_description(desc);
181 }
182 b.build()
183 }
184
185 pub async fn pull(self) -> Result<Self, ChatError> {
204 let model = self.model.as_ref().expect("model set");
205 let transport = self.transport.as_ref().expect("transport set");
206
207 let body = serde_json::to_vec(&json!({
208 "model": model,
209 "stream": false,
210 }))
211 .map_err(|e| ChatError::Other(e.to_string()))?;
212
213 let mut headers = vec![("Content-Type".into(), "application/json".into())];
214 if let Some(key) = &self.api_key {
215 headers.push(("Authorization".into(), format!("Bearer {key}")));
216 }
217 headers.extend(self.extra_headers.iter().cloned());
218
219 let req = Request {
220 scheme: self.scheme.clone(),
221 host: self.host.clone(),
222 path: "/api/pull".to_string(),
223 headers,
224 body,
225 };
226
227 let res = transport
228 .send(req)
229 .await
230 .map_err(|e| map_transport_error(&self.scheme, &self.host, e))?;
231 if !(200..300).contains(&res.status) {
232 let body = String::from_utf8_lossy(&res.body);
233 return Err(ChatError::Provider(format!(
234 "Ollama pull failed (HTTP {}): {body}",
235 res.status
236 )));
237 }
238
239 #[derive(Deserialize)]
240 struct PullResponse {
241 #[serde(default)]
242 status: Option<String>,
243 #[serde(default)]
244 error: Option<String>,
245 }
246
247 let parsed: PullResponse = serde_json::from_slice(&res.body).unwrap_or(PullResponse {
248 status: None,
249 error: None,
250 });
251
252 if let Some(err) = parsed.error {
253 return Err(ChatError::Provider(format!("Ollama pull: {err}")));
254 }
255 if let Some(status) = parsed.status
256 && status != "success"
257 && !status.is_empty()
258 {
259 return Err(ChatError::Provider(format!("Ollama pull status: {status}")));
260 }
261 Ok(self)
262 }
263}
264
265fn map_transport_error(scheme: &str, host: &str, err: TransportError) -> ChatError {
268 match &err {
269 TransportError::Connection(msg) => ChatError::Provider(format!(
270 "Ollama daemon unreachable at {scheme}://{host} ({msg}). \
271 Install from https://ollama.com/download, then run `ollama serve`. \
272 Override the host with OLLAMA_HOST=http://your-host:port."
273 )),
274 _ => ChatError::from(err),
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds via `with_host` and checks the scheme/host split it produced.
    fn assert_split(input: &str, scheme: &str, host: &str) {
        let b = OllamaBuilder::with_host(input);
        assert_eq!(b.scheme, scheme, "scheme parsed from {input:?}");
        assert_eq!(b.host, host, "host parsed from {input:?}");
    }

    #[test]
    fn parses_host_only() {
        assert_split("http://localhost:11434", "http", "localhost:11434");
    }

    #[test]
    fn parses_host_with_v1_path() {
        assert_split("http://localhost:11434/v1", "http", "localhost:11434");
    }

    #[test]
    fn parses_remote_host() {
        assert_split(
            "https://my-ollama.example.com:8443",
            "https",
            "my-ollama.example.com:8443",
        );
    }

    #[test]
    fn omits_scheme_default_port() {
        // The `url` crate normalises away ports equal to the scheme's default.
        assert_split("https://ollama.example.com", "https", "ollama.example.com");
        assert_split("http://ollama.example.com:80", "http", "ollama.example.com");
    }

    #[test]
    fn keeps_ipv6_brackets() {
        assert_split("http://[::1]:11434", "http", "[::1]:11434");
    }

    #[test]
    #[should_panic]
    fn rejects_url_without_host() {
        let _ = OllamaBuilder::with_host("http://");
    }
}