1use crate::agent::types::{
2 ChatMessage, InferenceEvent, TokenUsage, ToolCallFn, ToolCallResponse, ToolDefinition,
3};
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
/// Normalized result of one non-streaming chat completion.
///
/// Both provider implementations in this file translate their backend's
/// response format into this shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderResponse {
    /// Assistant message text, when the backend returned a string `content`.
    pub content: Option<String>,
    /// Tool invocations requested by the model, if the response carried any.
    pub tool_calls: Option<Vec<ToolCallResponse>>,
    /// Token accounting reported by the backend.
    pub usage: TokenUsage,
    /// Reason generation stopped, as reported (or assumed) by the provider.
    pub finish_reason: Option<String>,
}
18
/// Filter used when listing models from a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderModelKind {
    /// No filtering: every model matches.
    Any,
    /// Everything that is not recognised as an embedding model.
    Coding,
    /// Only models recognised as embedding models.
    Embed,
}
25
/// Common interface over local inference backends.
///
/// Implemented in this file by `LmsProvider` (LM Studio) and
/// `OllamaProvider` (Ollama).
#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Runs one non-streaming chat completion, optionally advertising `tools`.
    ///
    /// `model_override`, when present, replaces the provider's configured
    /// model for this call only. Errors are reported as display strings.
    async fn call_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        model_override: Option<&str>,
    ) -> Result<ProviderResponse, String>;

    /// Streams a chat completion, forwarding text deltas to `tx` as
    /// `InferenceEvent::Token` and finishing with `InferenceEvent::Done`.
    async fn stream(
        &self,
        messages: &[ChatMessage],
        tx: mpsc::Sender<InferenceEvent>,
    ) -> Result<(), Box<dyn std::error::Error>>;

    /// Reports whether the backend currently appears usable.
    async fn health_check(&self) -> bool;
    /// Detects the identifier of the model the backend would serve; an empty
    /// string means nothing suitable was found.
    async fn detect_model(&self) -> Result<String, String>;
    /// Detects the context length of the active model.
    async fn detect_context_length(&self) -> usize;
    /// Loads `model_id` with the provider's default context settings.
    async fn load_model(&self, model_id: &str) -> Result<(), String>;
    /// Loads `model_id`, requesting `context_length` tokens of context when given.
    async fn load_model_with_context(
        &self,
        model_id: &str,
        context_length: Option<usize>,
    ) -> Result<(), String>;
    /// Loads `model_id` for use as an embedding model.
    async fn load_embedding_model(&self, model_id: &str) -> Result<(), String>;
    /// Lists model identifiers matching `kind`; `loaded_only` restricts the
    /// result to models currently loaded by the backend.
    async fn list_models(
        &self,
        kind: ProviderModelKind,
        loaded_only: bool,
    ) -> Result<Vec<String>, String>;
    /// Unloads either every loaded model (`all == true`) or the one named by
    /// `model_id`, returning a human-readable status message.
    async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String>;
    /// Unloads an embedding model, returning a human-readable status message.
    async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String>;
    /// Optionally warms up the backend before first real use.
    async fn prewarm(&self) -> Result<(), String>;

    /// Returns the identifier of the currently loaded embedding model, if any.
    async fn get_embedding_model(&self) -> Option<String>;

    /// Human-readable provider name.
    fn name(&self) -> &str;
    /// Currently configured model identifier.
    fn current_model(&self) -> String;
    /// Currently configured context length.
    fn context_length(&self) -> usize;

    /// Replaces the configured model and context length.
    fn set_runtime_profile(&mut self, model: &str, context_length: usize);
}
68
/// `ModelProvider` backed by an LM Studio server.
pub struct LmsProvider {
    /// HTTP client used for every request to the server.
    pub client: Client,
    /// Endpoint that chat-completion requests are POSTed to.
    pub api_url: String,
    /// Server root used to build the model-management URLs
    /// (`/api/v0/...`, `/api/v1/...`).
    pub base_url: String,
    /// Model identifier sent with requests unless overridden per call.
    pub model: String,
    /// Configured context length reported by `context_length()`.
    pub context_length: usize,
    /// Helper used for the server probe and for the CLI fallbacks that are
    /// attempted when `binary_path` is set.
    pub lms: crate::agent::lms::LmsHarness,
}
77
/// Shortens a provider error body so it can be embedded in an error message.
///
/// Surrounding whitespace is stripped first. At most 240 characters (not
/// bytes) are kept, and `...` is appended when anything was cut off. A blank
/// body yields an empty string.
fn truncate_provider_error_body(body: &str) -> String {
    const MAX_CHARS: usize = 240;

    let text = body.trim();
    match text.char_indices().nth(MAX_CHARS) {
        // A 241st character exists: cut right before it and flag the truncation.
        Some((cut, _)) => format!("{}...", &text[..cut]),
        None => text.to_string(),
    }
}
91
92fn lms_message_to_json(message: &ChatMessage) -> Value {
93 let content = match &message.content {
94 crate::agent::types::MessageContent::Text(text) => Value::String(text.clone()),
95 crate::agent::types::MessageContent::Parts(parts) => serde_json::to_value(parts)
96 .unwrap_or_else(|_| Value::String(message.content.as_str().to_string())),
97 };
98
99 match message.role.as_str() {
100 "assistant" => {
101 let mut base = serde_json::json!({
102 "role": "assistant",
103 "content": content,
104 });
105 if let Some(calls) = &message.tool_calls {
106 let tool_calls: Vec<Value> = calls
107 .iter()
108 .map(|call| {
109 let arguments = if call.function.arguments.is_string() {
110 call.function.arguments.clone()
111 } else {
112 Value::String(call.function.arguments.to_string())
113 };
114 serde_json::json!({
115 "id": call.id,
116 "type": call.call_type,
117 "function": {
118 "name": call.function.name,
119 "arguments": arguments,
120 }
121 })
122 })
123 .collect();
124 if let Some(obj) = base.as_object_mut() {
125 obj.insert("tool_calls".to_string(), Value::Array(tool_calls));
126 }
127 }
128 base
129 }
130 "tool" => serde_json::json!({
131 "role": "tool",
132 "content": content,
133 "tool_call_id": message.tool_call_id.clone().unwrap_or_default(),
134 }),
135 _ => serde_json::json!({
136 "role": message.role,
137 "content": content,
138 }),
139 }
140}
141
142fn lms_messages_payload(messages: &[ChatMessage]) -> Vec<Value> {
143 messages.iter().map(lms_message_to_json).collect()
144}
145
/// Appends `candidate` (trimmed) to `models` unless it is blank or an
/// identical entry is already present.
fn push_unique_model(models: &mut Vec<String>, candidate: &str) {
    let name = candidate.trim();
    if name.is_empty() {
        return;
    }
    let already_listed = models.iter().any(|existing| existing.as_str() == name);
    if !already_listed {
        models.push(name.to_owned());
    }
}
152
153fn matches_lms_model_kind(kind: ProviderModelKind, raw_type: &str) -> bool {
154 match kind {
155 ProviderModelKind::Any => true,
156 ProviderModelKind::Coding => raw_type != "embedding" && raw_type != "embeddings",
157 ProviderModelKind::Embed => raw_type == "embedding" || raw_type == "embeddings",
158 }
159}
160
/// Name-based heuristic for spotting embedding models.
///
/// The comparison is ASCII case-insensitive and purely substring based, so it
/// can produce false positives for unrelated names that happen to contain one
/// of the markers.
fn looks_like_embedding_model_name(name: &str) -> bool {
    // "embed" also covers every name containing "embedding".
    const MARKERS: [&str; 4] = ["embed", "minilm", "bge", "e5"];

    let lowered = name.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
169
170#[async_trait]
171impl ModelProvider for LmsProvider {
172 async fn call_with_tools(
173 &self,
174 messages: &[ChatMessage],
175 tools: &[ToolDefinition],
176 model_override: Option<&str>,
177 ) -> Result<ProviderResponse, String> {
178 let model = model_override.unwrap_or(&self.model).to_string();
179 let payload_messages = lms_messages_payload(messages);
180 let request = serde_json::json!({
181 "model": model,
182 "messages": payload_messages,
183 "temperature": 0.2,
184 "stream": false,
185 "tools": if tools.is_empty() { None } else { Some(tools) },
186 });
187
188 let mut last_err = String::new();
189 for attempt in 0..3u32 {
190 match self.client.post(&self.api_url).json(&request).send().await {
191 Ok(res) if res.status().is_success() => {
192 let body: Value = res
193 .json()
194 .await
195 .map_err(|e| format!("LMS parse error: {}", e))?;
196 let choice = body["choices"].get(0).ok_or("Empty choice from LMS")?;
197 let message = &choice["message"];
198 let content = message["content"].as_str().map(|s| s.to_string());
199 let tool_calls: Option<Vec<ToolCallResponse>> =
200 serde_json::from_value(message["tool_calls"].clone()).ok();
201 let usage: TokenUsage =
202 serde_json::from_value(body["usage"].clone()).unwrap_or_default();
203 let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
204 return Ok(ProviderResponse {
205 content,
206 tool_calls,
207 usage,
208 finish_reason,
209 });
210 }
211 Ok(res) => {
212 let status = res.status();
213 let body = res.text().await.unwrap_or_default();
214 let body_note = truncate_provider_error_body(&body);
215 last_err = if body_note.is_empty() {
216 format!("HTTP {}", status)
217 } else {
218 format!("HTTP {} | {}", status, body_note)
219 };
220 }
221 Err(e) => {
222 last_err = e.to_string();
223 }
224 }
225 if attempt < 2 {
226 tokio::time::sleep(Duration::from_millis(500)).await;
227 }
228 }
229 Err(format!("LMS unreachable: {}", last_err))
230 }
231
232 async fn stream(
233 &self,
234 messages: &[ChatMessage],
235 tx: mpsc::Sender<InferenceEvent>,
236 ) -> Result<(), Box<dyn std::error::Error>> {
237 let request = serde_json::json!({
238 "model": self.model,
239 "messages": messages,
240 "temperature": 0.2,
241 "stream": true,
242 });
243
244 let res = self
245 .client
246 .post(&self.api_url)
247 .json(&request)
248 .send()
249 .await?;
250 if !res.status().is_success() {
251 return Err(format!("LMS stream error: {}", res.status()).into());
252 }
253
254 use futures::StreamExt;
255 let mut stream = res.bytes_stream();
256 while let Some(chunk) = stream.next().await {
257 let chunk = chunk?;
258 let text = String::from_utf8_lossy(&chunk);
259 for line in text.lines() {
260 if let Some(data) = line.strip_prefix("data: ") {
261 if data == "[DONE]" {
262 break;
263 }
264 if let Ok(v) = serde_json::from_str::<Value>(data) {
265 if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() {
266 let _ = tx.send(InferenceEvent::Token(delta.to_string())).await;
267 }
268 }
269 }
270 }
271 }
272 let _ = tx.send(InferenceEvent::Done).await;
273 Ok(())
274 }
275
276 async fn health_check(&self) -> bool {
277 if self.lms.is_server_responding(&self.base_url).await {
278 return true;
279 }
280 if self.lms.binary_path.is_some() {
281 let _ = self.lms.ensure_server_running();
282 tokio::time::sleep(Duration::from_millis(1500)).await;
283 return self.lms.is_server_responding(&self.base_url).await;
284 }
285 false
286 }
287
288 async fn detect_model(&self) -> Result<String, String> {
289 let base = self.base_url.trim_end_matches('/').trim_end_matches("/v1");
290 let url = format!("{}/api/v0/models", base);
291 if let Ok(res) = self.client.get(&url).send().await {
292 if res.status().is_success() {
293 let body: Value = res.json().await.map_err(|e| e.to_string())?;
294 if let Some(data) = body["data"].as_array() {
295 for m in data {
296 let m_type = m["type"].as_str().unwrap_or_default();
297 if (m_type == "chat" || m_type == "vlm" || m_type == "llm")
298 && m["state"].as_str() == Some("loaded")
299 {
300 return Ok(m["id"].as_str().unwrap_or_default().to_string());
301 }
302 }
303 }
304 }
305 }
306 let url_v1 = format!("{}/v1/models", base);
307 let resp_v1 = self
308 .client
309 .get(&url_v1)
310 .send()
311 .await
312 .map_err(|e| e.to_string())?;
313 let body_v1: Value = resp_v1.json().await.map_err(|e| e.to_string())?;
314 if let Some(data) = body_v1["data"].as_array() {
315 if let Some(first) = data.iter().find(|m| {
316 !m["id"]
317 .as_str()
318 .unwrap_or_default()
319 .to_lowercase()
320 .contains("embed")
321 }) {
322 return Ok(first["id"].as_str().unwrap_or_default().to_string());
323 }
324 }
325 Ok(String::new())
326 }
327
328 async fn detect_context_length(&self) -> usize {
329 let base = self.base_url.trim_end_matches('/').trim_end_matches("/v1");
330 let url = format!("{}/api/v0/models", base);
331 if let Ok(res) = self.client.get(&url).send().await {
332 if res.status().is_success() {
333 let body: Value = res.json().await.unwrap_or_default();
334 if let Some(data) = body["data"].as_array() {
335 for m in data {
336 let m_type = m["type"].as_str().unwrap_or_default();
337 if (m_type == "chat" || m_type == "vlm" || m_type == "llm")
338 && m["state"].as_str() == Some("loaded")
339 {
340 let fields = [
342 "loaded_context_length",
343 "context_length",
344 "max_context_length",
345 "contextLength",
346 ];
347
348 for field in fields {
350 if let Some(val) = m.get(field) {
351 if let Some(len) = val.as_u64() {
352 return len as usize;
353 }
354 if let Some(s) = val.as_str() {
355 if let Ok(len) = s.parse::<usize>() {
356 return len;
357 }
358 }
359 }
360 }
361
362 if let Some(stats) = m.get("stats") {
364 for field in fields {
365 if let Some(val) = stats.get(field) {
366 if let Some(len) = val.as_u64() {
367 return len as usize;
368 }
369 if let Some(s) = val.as_str() {
370 if let Ok(len) = s.parse::<usize>() {
371 return len;
372 }
373 }
374 }
375 }
376 }
377
378 if let Some(config) = m.get("config") {
380 for field in fields {
381 if let Some(val) = config.get(field) {
382 if let Some(len) = val.as_u64() {
383 return len as usize;
384 }
385 if let Some(s) = val.as_str() {
386 if let Ok(len) = s.parse::<usize>() {
387 return len;
388 }
389 }
390 }
391 }
392 }
393 }
394 }
395 }
396 }
397 }
398 0
399 }
400
    /// Loads `model_id` without requesting a specific context length.
    ///
    /// Thin wrapper around `load_model_with_context(model_id, None)`.
    async fn load_model(&self, model_id: &str) -> Result<(), String> {
        self.load_model_with_context(model_id, None).await
    }
404
405 async fn load_model_with_context(
406 &self,
407 model_id: &str,
408 context_length: Option<usize>,
409 ) -> Result<(), String> {
410 let mut payload = serde_json::json!({ "model": model_id });
411 if let Some(ctx) = context_length {
412 payload["context_length"] = serde_json::json!(ctx);
413 }
414
415 let load_url = format!("{}/api/v1/models/load", self.base_url);
416 if let Ok(res) = self.client.post(&load_url).json(&payload).send().await {
417 if res.status().is_success() {
418 return Ok(());
419 }
420 let body = res.text().await.unwrap_or_default();
421 let body_note = truncate_provider_error_body(&body);
422 if !body_note.is_empty() {
423 return Err(format!("Model load failed: {}", body_note));
424 }
425 }
426
427 if context_length.is_none()
428 && self.lms.binary_path.is_some()
429 && self.lms.load_model(model_id).is_ok()
430 {
431 return Ok(());
432 }
433
434 let payload = serde_json::json!({
435 "model": model_id,
436 "messages": [{"role": "system", "content": "System boot"}],
437 "max_tokens": 1,
438 "stream": false
439 });
440 match self.client.post(&self.api_url).json(&payload).send().await {
441 Ok(res) if res.status().is_success() => Ok(()),
442 Ok(res) => Err(format!("Model load failed: HTTP {}", res.status())),
443 Err(e) => Err(format!("Model load failed: {}", e)),
444 }
445 }
446
    /// Loads an embedding model.
    ///
    /// This provider has no embedding-specific load path; it delegates to
    /// `load_model`.
    async fn load_embedding_model(&self, model_id: &str) -> Result<(), String> {
        self.load_model(model_id).await
    }
450
451 async fn list_models(
452 &self,
453 kind: ProviderModelKind,
454 loaded_only: bool,
455 ) -> Result<Vec<String>, String> {
456 let mut models = Vec::new();
457
458 if loaded_only {
459 let url = format!("{}/api/v0/models", self.base_url);
460 if let Ok(res) = self.client.get(&url).send().await {
461 if res.status().is_success() {
462 let body: Value = res.json().await.map_err(|e| e.to_string())?;
463 if let Some(data) = body["data"].as_array() {
464 for model in data {
465 if model["state"].as_str() != Some("loaded") {
466 continue;
467 }
468 let raw_type = model["type"].as_str().unwrap_or_default();
469 if !matches_lms_model_kind(kind, raw_type) {
470 continue;
471 }
472 if let Some(id) = model["id"].as_str() {
473 push_unique_model(&mut models, id);
474 }
475 }
476 }
477 }
478 }
479
480 if models.is_empty()
481 && self.lms.binary_path.is_some()
482 && kind != ProviderModelKind::Embed
483 {
484 if let Ok(cli_models) = self.lms.list_loaded_models() {
485 for model in cli_models {
486 push_unique_model(&mut models, &model);
487 }
488 }
489 }
490 return Ok(models);
491 }
492
493 let url = format!("{}/api/v1/models", self.base_url);
494 if let Ok(res) = self.client.get(&url).send().await {
495 if res.status().is_success() {
496 let body: Value = res.json().await.map_err(|e| e.to_string())?;
497 if let Some(data) = body["data"].as_array() {
498 for model in data {
499 let raw_type = model["type"].as_str().unwrap_or_default();
500 if !matches_lms_model_kind(kind, raw_type) {
501 continue;
502 }
503 if let Some(id) = model["id"].as_str() {
504 push_unique_model(&mut models, id);
505 }
506 }
507 }
508 }
509 }
510
511 if models.is_empty() && self.lms.binary_path.is_some() && kind != ProviderModelKind::Embed {
512 if let Ok(cli_models) = self.lms.list_models() {
513 for model in cli_models {
514 push_unique_model(&mut models, &model);
515 }
516 }
517 }
518
519 Ok(models)
520 }
521
522 async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String> {
523 if all {
524 let loaded = self.list_models(ProviderModelKind::Any, true).await?;
525 if loaded.is_empty() {
526 return Ok("No LM Studio models are currently loaded.".to_string());
527 }
528
529 if self.lms.binary_path.is_some() && self.lms.unload_all_models().is_ok() {
530 return Ok(format!("Unloaded {} LM Studio model(s).", loaded.len()));
531 }
532
533 let unload_url = format!("{}/api/v1/models/unload", self.base_url);
534 let mut unloaded = 0usize;
535 let mut failures = Vec::new();
536 for instance_id in loaded {
537 match self
538 .client
539 .post(&unload_url)
540 .json(&serde_json::json!({ "instance_id": instance_id }))
541 .send()
542 .await
543 {
544 Ok(res) if res.status().is_success() => unloaded += 1,
545 Ok(res) => failures.push(format!("{} ({})", instance_id, res.status())),
546 Err(e) => failures.push(format!("{} ({})", instance_id, e)),
547 }
548 }
549 if failures.is_empty() {
550 return Ok(format!("Unloaded {} LM Studio model(s).", unloaded));
551 }
552 return Err(format!(
553 "Unloaded {} LM Studio model(s), but some unloads failed: {}",
554 unloaded,
555 failures.join(", ")
556 ));
557 }
558
559 let target = model_id
560 .map(str::trim)
561 .filter(|value| !value.is_empty())
562 .ok_or_else(|| "Missing model ID to unload.".to_string())?;
563
564 let unload_url = format!("{}/api/v1/models/unload", self.base_url);
565 match self
566 .client
567 .post(&unload_url)
568 .json(&serde_json::json!({ "instance_id": target }))
569 .send()
570 .await
571 {
572 Ok(res) if res.status().is_success() => {
573 Ok(format!("Unloaded LM Studio model `{}`.", target))
574 }
575 Ok(res) => {
576 let status = res.status();
577 let body = res.text().await.unwrap_or_default();
578 let body_note = truncate_provider_error_body(&body);
579 if self.lms.binary_path.is_some() && self.lms.unload_model(target).is_ok() {
580 Ok(format!("Unloaded LM Studio model `{}`.", target))
581 } else if body_note.is_empty() {
582 Err(format!("LM Studio unload failed: HTTP {}", status))
583 } else {
584 Err(format!(
585 "LM Studio unload failed: HTTP {} | {}",
586 status, body_note
587 ))
588 }
589 }
590 Err(err) => {
591 if self.lms.binary_path.is_some() && self.lms.unload_model(target).is_ok() {
592 Ok(format!("Unloaded LM Studio model `{}`.", target))
593 } else {
594 Err(format!("LM Studio unload failed: {}", err))
595 }
596 }
597 }
598 }
599
    /// Unloads an embedding model.
    ///
    /// Delegates to `unload_model` with `all == false`, so a missing or blank
    /// `model_id` is an error.
    async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String> {
        self.unload_model(model_id, false).await
    }
603
604 async fn prewarm(&self) -> Result<(), String> {
605 let payload = serde_json::json!({
606 "model": self.model,
607 "messages": [{"role": "system", "content": "Hematite BootSequence"}],
608 "max_tokens": 1,
609 "stream": false
610 });
611 let _ = self.client.post(&self.api_url).json(&payload).send().await;
612 Ok(())
613 }
614
615 async fn get_embedding_model(&self) -> Option<String> {
616 let url = format!("{}/api/v0/models", self.base_url);
617 if let Ok(res) = self.client.get(&url).send().await {
618 if let Ok(body) = res.json::<Value>().await {
619 if let Some(data) = body["data"].as_array() {
620 return data
621 .iter()
622 .find(|m| {
623 m["type"].as_str() == Some("embeddings")
624 && m["state"].as_str() == Some("loaded")
625 })
626 .map(|m| m["id"].as_str().unwrap_or_default().to_string());
627 }
628 }
629 }
630 None
631 }
632
    /// Human-readable provider name.
    fn name(&self) -> &str {
        "LM Studio"
    }
    /// Returns a copy of the configured model identifier.
    fn current_model(&self) -> String {
        self.model.clone()
    }
    /// Returns the configured context length (not re-detected from the server).
    fn context_length(&self) -> usize {
        self.context_length
    }
    /// Overwrites the configured model and context length; no request is sent.
    fn set_runtime_profile(&mut self, model: &str, context_length: usize) {
        self.model = model.to_string();
        self.context_length = context_length;
    }
646}
647
/// `ModelProvider` backed by an Ollama server.
pub struct OllamaProvider {
    /// HTTP client used for every request to the server.
    pub client: Client,
    /// Server root; endpoints are built as `{base_url}/api/...`.
    pub base_url: String,
    /// Model name sent with requests unless overridden per call.
    pub model: String,
    /// Configured context length, also used as the fallback for detection.
    pub context_length: usize,
    /// Name of the embedding model most recently loaded through this
    /// provider, cleared again when that model is unloaded.
    pub embed_model: std::sync::Arc<std::sync::RwLock<Option<String>>>,
    /// Helper used for the reachability probe and the "is this model pulled"
    /// check.
    pub ollama: crate::agent::ollama::OllamaHarness,
}
656
657#[async_trait]
658impl ModelProvider for OllamaProvider {
659 async fn call_with_tools(
660 &self,
661 messages: &[ChatMessage],
662 tools: &[ToolDefinition],
663 model_override: Option<&str>,
664 ) -> Result<ProviderResponse, String> {
665 let model = model_override.unwrap_or(&self.model).to_string();
666 let url = format!("{}/api/chat", self.base_url);
667 let request = serde_json::json!({
668 "model": model, "messages": messages, "stream": false,
669 "tools": if tools.is_empty() { None } else { Some(tools) },
670 });
671 let res = self
672 .client
673 .post(&url)
674 .json(&request)
675 .send()
676 .await
677 .map_err(|e| e.to_string())?;
678 if !res.status().is_success() {
679 return Err(format!("Ollama error: {}", res.status()));
680 }
681 let body: Value = res.json().await.map_err(|e| e.to_string())?;
682 let message = &body["message"];
683 let content = message["content"].as_str().map(|s| s.to_string());
684 let tool_calls = if let Some(calls) = message["tool_calls"].as_array() {
685 let mut mapped = Vec::new();
686 for (i, c) in calls.iter().enumerate() {
687 mapped.push(ToolCallResponse {
688 id: format!("call_{}", i),
689 call_type: "function".to_string(),
690 function: ToolCallFn {
691 name: c["function"]["name"]
692 .as_str()
693 .unwrap_or_default()
694 .to_string(),
695 arguments: c["function"]["arguments"].clone(),
696 },
697 index: Some(i as i32),
698 });
699 }
700 Some(mapped)
701 } else {
702 None
703 };
704 let usage = TokenUsage {
705 prompt_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0) as usize,
706 completion_tokens: body["eval_count"].as_u64().unwrap_or(0) as usize,
707 ..Default::default()
708 };
709 Ok(ProviderResponse {
710 content,
711 tool_calls,
712 usage,
713 finish_reason: Some("stop".to_string()),
714 })
715 }
716
717 async fn stream(
718 &self,
719 messages: &[ChatMessage],
720 tx: mpsc::Sender<InferenceEvent>,
721 ) -> Result<(), Box<dyn std::error::Error>> {
722 let url = format!("{}/api/chat", self.base_url);
723 let request =
724 serde_json::json!({ "model": self.model, "messages": messages, "stream": true });
725 let res = self.client.post(&url).json(&request).send().await?;
726 use futures::StreamExt;
727 let mut stream = res.bytes_stream();
728 while let Some(chunk) = stream.next().await {
729 let chunk = chunk?;
730 if let Ok(v) = serde_json::from_slice::<Value>(&chunk) {
731 if let Some(delta) = v["message"]["content"].as_str() {
732 let _ = tx.send(InferenceEvent::Token(delta.to_string())).await;
733 }
734 if v["done"].as_bool().unwrap_or(false) {
735 break;
736 }
737 }
738 }
739 let _ = tx.send(InferenceEvent::Done).await;
740 Ok(())
741 }
742
    /// Reports whether Ollama is reachable, as determined by the
    /// `OllamaHarness::is_reachable` probe.
    async fn health_check(&self) -> bool {
        self.ollama.is_reachable().await
    }
746 async fn detect_model(&self) -> Result<String, String> {
747 let running_url = format!("{}/api/ps", self.base_url);
748 if let Ok(resp) = self.client.get(&running_url).send().await {
749 let body: Value = resp.json().await.map_err(|e| e.to_string())?;
750 if let Some(models) = body["models"].as_array() {
751 if let Some(first) = models.first() {
752 let name = first["name"]
753 .as_str()
754 .or_else(|| first["model"].as_str())
755 .unwrap_or_default();
756 return Ok(name.to_string());
757 }
758 return Ok(String::new());
759 }
760 }
761
762 if !self.model.trim().is_empty() {
763 return Ok(self.model.clone());
764 }
765
766 let url = format!("{}/api/tags", self.base_url);
767 let resp = self
768 .client
769 .get(&url)
770 .send()
771 .await
772 .map_err(|e| e.to_string())?;
773 let body: Value = resp.json().await.map_err(|e| e.to_string())?;
774 if let Some(models) = body["models"].as_array() {
775 if let Some(first) = models.first() {
776 return Ok(first["name"].as_str().unwrap_or_default().to_string());
777 }
778 }
779 Ok(String::new())
780 }
781 async fn detect_context_length(&self) -> usize {
782 let running_url = format!("{}/api/ps", self.base_url);
783 if let Ok(resp) = self.client.get(&running_url).send().await {
784 if let Ok(body) = resp.json::<Value>().await {
785 if let Some(models) = body["models"].as_array() {
786 if let Some(first) = models.first() {
787 if let Some(context_length) = first["context_length"].as_u64() {
788 return context_length as usize;
789 }
790 }
791 }
792 }
793 }
794 self.context_length
795 }
796 async fn load_model(&self, _model_id: &str) -> Result<(), String> {
797 self.load_model_with_context(_model_id, None).await
798 }
799 async fn load_model_with_context(
800 &self,
801 model_id: &str,
802 context_length: Option<usize>,
803 ) -> Result<(), String> {
804 if !self.ollama.has_model(model_id).await? {
805 return Err(format!(
806 "Ollama model `{}` is not pulled locally. Run `ollama pull {}` first.",
807 model_id, model_id
808 ));
809 }
810 let url = format!("{}/api/generate", self.base_url);
811 let request = serde_json::json!({
812 "model": model_id,
813 "prompt": "Hematite runtime warmup",
814 "stream": false,
815 "keep_alive": "30m",
816 "options": {
817 "num_ctx": context_length.unwrap_or(self.context_length.max(4096))
818 }
819 });
820 let res = self
821 .client
822 .post(&url)
823 .json(&request)
824 .send()
825 .await
826 .map_err(|e| e.to_string())?;
827 let status = res.status();
828 if status.is_success() {
829 Ok(())
830 } else {
831 let body = res.text().await.unwrap_or_default();
832 let body_note = truncate_provider_error_body(&body);
833 if body_note.is_empty() {
834 Err(format!("Ollama load failed: HTTP {}", status))
835 } else {
836 Err(format!(
837 "Ollama load failed: HTTP {} | {}",
838 status, body_note
839 ))
840 }
841 }
842 }
843 async fn load_embedding_model(&self, model_id: &str) -> Result<(), String> {
844 if !self.ollama.has_model(model_id).await? {
845 return Err(format!(
846 "Ollama embedding model `{}` is not pulled locally. Run `ollama pull {}` first.",
847 model_id, model_id
848 ));
849 }
850 let url = format!("{}/api/embed", self.base_url);
851 let request = serde_json::json!({
852 "model": model_id,
853 "input": "search_document: Hematite semantic search warmup",
854 "keep_alive": "30m"
855 });
856 let res = self
857 .client
858 .post(&url)
859 .json(&request)
860 .send()
861 .await
862 .map_err(|e| e.to_string())?;
863 let status = res.status();
864 if !status.is_success() {
865 let body = res.text().await.unwrap_or_default();
866 let body_note = truncate_provider_error_body(&body);
867 return if body_note.is_empty() {
868 Err(format!("Ollama embed load failed: HTTP {}", status))
869 } else {
870 Err(format!(
871 "Ollama embed load failed: HTTP {} | {}",
872 status, body_note
873 ))
874 };
875 }
876 if let Ok(mut guard) = self.embed_model.write() {
877 *guard = Some(model_id.to_string());
878 }
879 Ok(())
880 }
881 async fn list_models(
882 &self,
883 kind: ProviderModelKind,
884 loaded_only: bool,
885 ) -> Result<Vec<String>, String> {
886 let url = if loaded_only {
887 format!("{}/api/ps", self.base_url)
888 } else {
889 format!("{}/api/tags", self.base_url)
890 };
891 let resp = self
892 .client
893 .get(&url)
894 .send()
895 .await
896 .map_err(|e| e.to_string())?;
897 let body: Value = resp.json().await.map_err(|e| e.to_string())?;
898 let mut models = Vec::new();
899 if let Some(entries) = body["models"].as_array() {
900 for entry in entries {
901 let name = entry["name"]
902 .as_str()
903 .or_else(|| entry["model"].as_str())
904 .unwrap_or_default();
905 if kind == ProviderModelKind::Embed && !looks_like_embedding_model_name(name) {
906 continue;
907 }
908 if kind == ProviderModelKind::Coding && looks_like_embedding_model_name(name) {
909 continue;
910 }
911 push_unique_model(&mut models, name);
912 }
913 }
914 if loaded_only && kind == ProviderModelKind::Embed {
915 if let Ok(guard) = self.embed_model.read() {
916 if let Some(model) = guard.as_deref() {
917 push_unique_model(&mut models, model);
918 }
919 }
920 }
921 Ok(models)
922 }
923 async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String> {
924 let targets = if all {
925 self.list_models(ProviderModelKind::Coding, true).await?
926 } else {
927 vec![model_id
928 .map(str::trim)
929 .filter(|value| !value.is_empty())
930 .ok_or_else(|| "Missing model ID to unload.".to_string())?
931 .to_string()]
932 };
933
934 if targets.is_empty() {
935 return Ok("No Ollama models are currently loaded.".to_string());
936 }
937
938 let url = format!("{}/api/generate", self.base_url);
939 let mut unloaded = 0usize;
940 let mut failures = Vec::new();
941 for target in targets {
942 let request = serde_json::json!({
943 "model": target,
944 "prompt": "",
945 "stream": false,
946 "keep_alive": 0
947 });
948 match self.client.post(&url).json(&request).send().await {
949 Ok(res) if res.status().is_success() => unloaded += 1,
950 Ok(res) => failures.push(format!("{} ({})", target, res.status())),
951 Err(e) => failures.push(format!("{} ({})", target, e)),
952 }
953 }
954
955 if failures.is_empty() {
956 return Ok(if all {
957 format!("Unloaded {} Ollama model(s).", unloaded)
958 } else {
959 format!("Unloaded Ollama model `{}`.", model_id.unwrap_or_default())
960 });
961 }
962
963 Err(format!(
964 "Unloaded {} Ollama model(s), but some unloads failed: {}",
965 unloaded,
966 failures.join(", ")
967 ))
968 }
969 async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String> {
970 let target = match model_id {
971 Some(explicit) if !explicit.trim().is_empty() => explicit.trim().to_string(),
972 _ => self
973 .get_embedding_model()
974 .await
975 .ok_or_else(|| "No Ollama embedding model is currently loaded.".to_string())?,
976 };
977 let url = format!("{}/api/embed", self.base_url);
978 let request = serde_json::json!({
979 "model": target,
980 "input": "search_document: Hematite semantic search warmup",
981 "keep_alive": 0
982 });
983 let res = self
984 .client
985 .post(&url)
986 .json(&request)
987 .send()
988 .await
989 .map_err(|e| e.to_string())?;
990 if res.status().is_success() {
991 if let Ok(mut guard) = self.embed_model.write() {
992 if guard.as_deref() == Some(target.as_str()) {
993 *guard = None;
994 }
995 }
996 Ok(format!("Unloaded Ollama embedding model `{}`.", target))
997 } else {
998 let status = res.status();
999 let body = res.text().await.unwrap_or_default();
1000 let body_note = truncate_provider_error_body(&body);
1001 if body_note.is_empty() {
1002 Err(format!("Ollama embed unload failed: HTTP {}", status))
1003 } else {
1004 Err(format!(
1005 "Ollama embed unload failed: HTTP {} | {}",
1006 status, body_note
1007 ))
1008 }
1009 }
1010 }
    /// No-op for this provider: nothing is sent and `Ok(())` is always returned.
    async fn prewarm(&self) -> Result<(), String> {
        Ok(())
    }
1014 async fn get_embedding_model(&self) -> Option<String> {
1015 if let Ok(guard) = self.embed_model.read() {
1016 if let Some(model) = guard.as_ref() {
1017 return Some(model.clone());
1018 }
1019 }
1020
1021 let url = format!("{}/api/ps", self.base_url);
1022 if let Ok(res) = self.client.get(&url).send().await {
1023 if let Ok(body) = res.json::<Value>().await {
1024 if let Some(entries) = body["models"].as_array() {
1025 for entry in entries {
1026 let name = entry["name"]
1027 .as_str()
1028 .or_else(|| entry["model"].as_str())
1029 .unwrap_or_default();
1030 if looks_like_embedding_model_name(name) {
1031 return Some(name.to_string());
1032 }
1033 }
1034 }
1035 }
1036 }
1037 None
1038 }
1039
    /// Human-readable provider name.
    fn name(&self) -> &str {
        "Ollama"
    }
    /// Returns a copy of the configured model name.
    fn current_model(&self) -> String {
        self.model.clone()
    }
    /// Returns the configured context length (not re-detected from the server).
    fn context_length(&self) -> usize {
        self.context_length
    }
    /// Overwrites the configured model and context length; no request is sent.
    fn set_runtime_profile(&mut self, model: &str, context_length: usize) {
        self.model = model.to_string();
        self.context_length = context_length;
    }
1053}
1054
#[cfg(test)]
mod tests {
    use super::{
        lms_messages_payload, looks_like_embedding_model_name, matches_lms_model_kind,
        ProviderModelKind,
    };
    use crate::agent::types::{ChatMessage, ToolCallFn, ToolCallResponse};
    use serde_json::json;

    /// Object-valued tool arguments must reach LM Studio as a JSON string.
    #[test]
    fn lms_payload_stringifies_assistant_tool_arguments() {
        let call = ToolCallResponse {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: ToolCallFn {
                name: "read_file".to_string(),
                arguments: json!({"path":"index.html"}),
            },
            index: None,
        };
        let messages = vec![ChatMessage::assistant_tool_calls("", vec![call])];

        let payload = lms_messages_payload(&messages);
        let args = &payload[0]["tool_calls"][0]["function"]["arguments"];

        assert!(args.is_string());
        assert_eq!(
            args.as_str().unwrap_or_default(),
            r#"{"path":"index.html"}"#
        );
    }

    /// `Coding` and `Embed` must split model types on the embedding marker.
    #[test]
    fn lms_model_kind_matching_distinguishes_embedding_models() {
        let accepted = [
            (ProviderModelKind::Coding, "chat"),
            (ProviderModelKind::Embed, "embeddings"),
        ];
        for (kind, raw_type) in accepted {
            assert!(matches_lms_model_kind(kind, raw_type));
        }

        let rejected = [
            (ProviderModelKind::Coding, "embeddings"),
            (ProviderModelKind::Embed, "chat"),
        ];
        for (kind, raw_type) in rejected {
            assert!(!matches_lms_model_kind(kind, raw_type));
        }
    }

    /// The name heuristic must flag well-known embedding models only.
    #[test]
    fn embedding_name_heuristic_catches_common_ollama_embed_models() {
        for name in ["embeddinggemma", "qwen3-embedding", "all-minilm"] {
            assert!(looks_like_embedding_model_name(name));
        }
        assert!(!looks_like_embedding_model_name("qwen3.5:latest"));
    }
}