// cortex_llm/ollama_http.rs
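//! Blocking HTTP adapter for Ollama's chat API, bridged onto the async
//! `LlmAdapter` trait via `tokio::task::spawn_blocking`.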
use std::time::Duration;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};

use crate::adapter::{
    blake3_hex, BoxStream, LlmAdapter, LlmError, LlmRequest, LlmResponse, LlmRole, StreamChunk,
};
use crate::ollama::{validate_endpoint_url, validate_model_ref, OllamaConfig};

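/// Adapter that talks to an Ollama server over its HTTP chat API.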
#[derive(Debug, Clone)]
pub struct OllamaHttpAdapter {
    config: OllamaConfig,
}

impl OllamaHttpAdapter {
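    /// Builds an adapter after validating the configured endpoint URL,
    /// so a malformed endpoint fails at construction time rather than
    /// on the first request.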
    pub fn new(config: OllamaConfig) -> Result<Self, LlmError> {
        validate_endpoint_url(&config.endpoint_url)?;
        Ok(Self { config })
    }
}

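/// Request body for Ollama's `POST /api/chat` endpoint.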
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
    model: &'a str,
    messages: Vec<OllamaMessage<'a>>,
    stream: bool,
}

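/// A single chat message in the wire format Ollama expects.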
#[derive(Debug, Serialize)]
struct OllamaMessage<'a> {
    role: &'a str,
    content: &'a str,
}

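/// Non-streaming response envelope; only the message content is used.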
#[derive(Debug, Deserialize)]
struct ChatResponse {
    #[serde(default)]
    message: MessageField,
}

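/// The `message` object shared by non-streaming and streaming responses;
/// `default` keeps a missing field from failing deserialization.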
#[derive(Debug, Default, Deserialize)]
struct MessageField {
    #[serde(default)]
    content: String,
}

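/// One line of Ollama's newline-delimited JSON streaming response.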
#[derive(Debug, Deserialize)]
struct StreamLine {
    #[serde(default)]
    message: MessageField,
    #[serde(default)]
    done: bool,
    done_reason: Option<String>,
}

#[async_trait]
impl LlmAdapter for OllamaHttpAdapter {
    fn adapter_id(&self) -> &'static str {
        "ollama"
    }

    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        validate_model_ref(&self.config.model)?;
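        // The adapter always answers for its configured model; override
        // whatever model the caller set on the request.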
        let req = LlmRequest { model: self.config.model.clone(), ..req };

        let config = self.config.clone();
        let timeout_ms = req.timeout_ms;

        let result = tokio::task::spawn_blocking(move || call_ollama(&config, &req, timeout_ms))
            .await
            .map_err(|e| LlmError::Transport(format!("spawn_blocking join error: {e}")))?;

        result
    }

    fn stream_boxed(&self, req: LlmRequest) -> BoxStream<'_> {
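        // Model validation happens inside the returned stream, since this
        // method cannot return a `Result` directly.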
        let req = LlmRequest { model: self.config.model.clone(), ..req };
        validate_model_ref_and_stream(self.config.clone(), req)
    }
}

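/// Performs one blocking, non-streaming chat request against Ollama.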
fn call_ollama(
    config: &OllamaConfig,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Result<LlmResponse, LlmError> {
    let url = format!("{}/api/chat", config.endpoint_url);

    let messages: Vec<OllamaMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OllamaMessage {
            role: m.role.as_str(),
            content: &m.content,
        })
        .collect();

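    // Ollama expects the bare model name, so keep only the part of the
    // model ref before any '@' suffix.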
    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);

    let body = ChatRequest {
        model: ollama_model,
        messages,
        stream: false,
    };

    let timeout = Duration::from_millis(timeout_ms);

    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let raw_response = agent
        .post(&url)
        .send_json(
            serde_json::to_value(&body)
                .map_err(|e| LlmError::Transport(format!("request serialization failed: {e}")))?,
        )
        .map_err(|err| map_ureq_error(err, timeout_ms))?;

    let status = raw_response.status();
    if status != 200 {
        return Err(LlmError::Transport(format!("HTTP {status}")));
    }

    let response_text = raw_response
        .into_string()
        .map_err(|e| LlmError::Transport(format!("reading response body: {e}")))?;

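    // Refuse oversized bodies before hashing or storing them.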
    const MAX_RESPONSE_BYTES: usize = 16 * 1024 * 1024;
    if response_text.len() > MAX_RESPONSE_BYTES {
        return Err(LlmError::Transport(format!(
            "ollama response body exceeds 16 MiB limit ({} bytes); refusing to store",
            response_text.len()
        )));
    }

    let parsed: ChatResponse = serde_json::from_str(&response_text)
        .map_err(|e| LlmError::Parse(format!("ollama response parse: {e}")))?;

    let text = parsed.message.content;
    let raw_hash = blake3_hex(response_text.as_bytes());

    Ok(LlmResponse {
        text,
        parsed_json: None,
        model: config.model.clone(),
        usage: None,
        raw_hash,
    })
}

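/// Validates the model ref inside the stream itself, then bridges the
/// blocking streaming call onto an async stream, yielding each chunk.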
fn validate_model_ref_and_stream(config: OllamaConfig, req: LlmRequest) -> BoxStream<'static> {
    Box::pin(async_stream::stream! {
        if let Err(e) = validate_model_ref(&config.model) {
            yield Err(e);
            return;
        }

        let timeout_ms = req.timeout_ms;
        let result = tokio::task::spawn_blocking(move || {
            call_ollama_streaming(&config, &req, timeout_ms)
        })
        .await;

        match result {
            Ok(chunks) => {
                for chunk in chunks {
                    yield chunk;
                }
            }
            Err(e) => yield Err(LlmError::Transport(format!("spawn_blocking join error: {e}"))),
        }
    })
}

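/// Performs the blocking streaming request: reads the full response body,
/// then parses each non-empty NDJSON line into a `StreamChunk`.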
fn call_ollama_streaming(
    config: &OllamaConfig,
    req: &LlmRequest,
    timeout_ms: u64,
) -> Vec<Result<StreamChunk, LlmError>> {
    let url = format!("{}/api/chat", config.endpoint_url);

    let messages: Vec<OllamaMessage<'_>> = req
        .messages
        .iter()
        .map(|m| OllamaMessage {
            role: m.role.as_str(),
            content: &m.content,
        })
        .collect();

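    // Same model-name normalization as in `call_ollama`.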
    let ollama_model = req.model.split('@').next().unwrap_or(&req.model);

    let body = ChatRequest {
        model: ollama_model,
        messages,
        stream: true,
    };

    let timeout = Duration::from_millis(timeout_ms);
    let agent = ureq::AgentBuilder::new().timeout(timeout).build();

    let body_value = match serde_json::to_value(&body) {
        Ok(v) => v,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "request serialization failed: {e}"
            )))]
        }
    };

    let raw_response = match agent.post(&url).send_json(body_value) {
        Ok(r) => r,
        Err(err) => return vec![Err(map_ureq_error(err, timeout_ms))],
    };

    let status = raw_response.status();
    if status != 200 {
        return vec![Err(LlmError::Transport(format!("HTTP {status}")))];
    }

    let body_text = match raw_response.into_string() {
        Ok(s) => s,
        Err(e) => {
            return vec![Err(LlmError::Transport(format!(
                "reading streaming response body: {e}"
            )))]
        }
    };

    body_text
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| {
            let parsed: StreamLine = serde_json::from_str(line)
                .map_err(|e| LlmError::Parse(format!("ollama stream line parse: {e}")))?;
            Ok(StreamChunk {
                delta: parsed.message.content,
                finish_reason: if parsed.done {
                    parsed.done_reason
                } else {
                    None
                },
            })
        })
        .collect()
}

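/// Maps a `ureq` error onto `LlmError`, promoting transport failures
/// that look like timeouts to `LlmError::Timeout`.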
fn map_ureq_error(err: ureq::Error, timeout_ms: u64) -> LlmError {
    match err {
        ureq::Error::Transport(t) => {
            let msg = t.to_string();
            if is_timeout_message(&msg) {
                LlmError::Timeout { timeout_ms }
            } else {
                LlmError::Transport(msg)
            }
        }
        ureq::Error::Status(code, _) => LlmError::Transport(format!("HTTP {code}")),
    }
}

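/// String heuristic for deciding whether a transport error message
/// describes a timeout.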
fn is_timeout_message(msg: &str) -> bool {
    let lower = msg.to_ascii_lowercase();
    lower.contains("timed out") || lower.contains("deadline exceeded") || lower.contains("timeout")
}

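/// Maps `LlmRole` variants onto the role strings Ollama's chat API expects.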
impl LlmRole {
    fn as_str(self) -> &'static str {
        match self {
            LlmRole::User => "user",
            LlmRole::Assistant => "assistant",
            LlmRole::Tool => "tool",
        }
    }
}