Skip to main content

hematite/agent/
provider.rs

1use crate::agent::types::{
2    ChatMessage, InferenceEvent, TokenUsage, ToolCallFn, ToolCallResponse, ToolDefinition,
3};
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ProviderResponse {
13    pub content: Option<String>,
14    pub tool_calls: Option<Vec<ToolCallResponse>>,
15    pub usage: TokenUsage,
16    pub finish_reason: Option<String>,
17}
18
/// Model category used to filter provider model listings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderModelKind {
    /// Accept every model regardless of its reported type.
    Any,
    /// Accept models whose reported type is not an embedding type.
    Coding,
    /// Accept only models whose reported type is an embedding type.
    Embed,
}
25
26#[async_trait]
27pub trait ModelProvider: Send + Sync {
28    async fn call_with_tools(
29        &self,
30        messages: &[ChatMessage],
31        tools: &[ToolDefinition],
32        model_override: Option<&str>,
33    ) -> Result<ProviderResponse, String>;
34
35    async fn stream(
36        &self,
37        messages: &[ChatMessage],
38        tx: mpsc::Sender<InferenceEvent>,
39    ) -> Result<(), Box<dyn std::error::Error>>;
40
41    async fn health_check(&self) -> bool;
42    async fn detect_model(&self) -> Result<String, String>;
43    async fn detect_context_length(&self) -> usize;
44    async fn load_model(&self, model_id: &str) -> Result<(), String>;
45    async fn load_model_with_context(
46        &self,
47        model_id: &str,
48        context_length: Option<usize>,
49    ) -> Result<(), String>;
50    async fn load_embedding_model(&self, model_id: &str) -> Result<(), String>;
51    async fn list_models(
52        &self,
53        kind: ProviderModelKind,
54        loaded_only: bool,
55    ) -> Result<Vec<String>, String>;
56    async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String>;
57    async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String>;
58    async fn prewarm(&self) -> Result<(), String>;
59
60    async fn get_embedding_model(&self) -> Option<String>;
61
62    fn name(&self) -> &str;
63    fn current_model(&self) -> String;
64    fn context_length(&self) -> usize;
65
66    fn set_runtime_profile(&mut self, model: &str, context_length: usize);
67}
68
69pub struct LmsProvider {
70    pub client: Client,
71    pub api_url: String,
72    pub base_url: String,
73    pub model: String,
74    pub context_length: usize,
75    pub lms: crate::agent::lms::LmsHarness,
76}
77
/// Condenses an HTTP error body for inclusion in an error message.
///
/// The body is trimmed; at most 240 characters are kept, and `...` is
/// appended when anything was cut off. A blank body yields an empty string.
fn truncate_provider_error_body(body: &str) -> String {
    const MAX_CHARS: usize = 240;

    let mut remaining = body.trim().chars();
    let mut note: String = remaining.by_ref().take(MAX_CHARS).collect();
    // Anything still left in the iterator is text that did not fit.
    if remaining.next().is_some() {
        note.push_str("...");
    }
    note
}
90
91fn lms_message_to_json(message: &ChatMessage) -> Value {
92    let content = match &message.content {
93        crate::agent::types::MessageContent::Text(text) => Value::String(text.clone()),
94        crate::agent::types::MessageContent::Parts(parts) => serde_json::to_value(parts)
95            .unwrap_or_else(|_| Value::String(message.content.as_str().to_string())),
96    };
97
98    match message.role.as_str() {
99        "assistant" => {
100            let mut base = serde_json::json!({
101                "role": "assistant",
102                "content": content,
103            });
104            if let Some(calls) = &message.tool_calls {
105                let tool_calls: Vec<Value> = calls
106                    .iter()
107                    .map(|call| {
108                        let arguments = if call.function.arguments.is_string() {
109                            call.function.arguments.clone()
110                        } else {
111                            Value::String(call.function.arguments.to_string())
112                        };
113                        serde_json::json!({
114                            "id": call.id,
115                            "type": call.call_type,
116                            "function": {
117                                "name": call.function.name,
118                                "arguments": arguments,
119                            }
120                        })
121                    })
122                    .collect();
123                if let Some(obj) = base.as_object_mut() {
124                    obj.insert("tool_calls".to_string(), Value::Array(tool_calls));
125                }
126            }
127            base
128        }
129        "tool" => serde_json::json!({
130            "role": "tool",
131            "content": content,
132            "tool_call_id": message.tool_call_id.clone().unwrap_or_default(),
133        }),
134        _ => serde_json::json!({
135            "role": message.role,
136            "content": content,
137        }),
138    }
139}
140
141fn lms_messages_payload(messages: &[ChatMessage]) -> Vec<Value> {
142    messages.iter().map(lms_message_to_json).collect()
143}
144
/// Appends the trimmed `candidate` to `models` unless it is blank or an
/// identical entry is already present.
fn push_unique_model(models: &mut Vec<String>, candidate: &str) {
    let candidate = candidate.trim();
    if candidate.is_empty() {
        return;
    }
    let already_listed = models.iter().any(|known| known.as_str() == candidate);
    if !already_listed {
        models.push(candidate.to_owned());
    }
}
151
152fn matches_lms_model_kind(kind: ProviderModelKind, raw_type: &str) -> bool {
153    match kind {
154        ProviderModelKind::Any => true,
155        ProviderModelKind::Coding => raw_type != "embedding" && raw_type != "embeddings",
156        ProviderModelKind::Embed => raw_type == "embedding" || raw_type == "embeddings",
157    }
158}
159
/// Name-based guess at whether `name` denotes an embedding model.
///
/// The match is ASCII case-insensitive and purely substring-based, so any
/// name containing one of the markers is reported as an embedding model.
fn looks_like_embedding_model_name(name: &str) -> bool {
    // "embed" also covers every name that contains "embedding".
    const MARKERS: [&str; 4] = ["embed", "minilm", "bge", "e5"];

    let lowered = name.to_ascii_lowercase();
    MARKERS.iter().any(|marker| lowered.contains(*marker))
}
168
169#[async_trait]
170impl ModelProvider for LmsProvider {
171    async fn call_with_tools(
172        &self,
173        messages: &[ChatMessage],
174        tools: &[ToolDefinition],
175        model_override: Option<&str>,
176    ) -> Result<ProviderResponse, String> {
177        let model = model_override.unwrap_or(&self.model).to_string();
178        let payload_messages = lms_messages_payload(messages);
179        let request = serde_json::json!({
180            "model": model,
181            "messages": payload_messages,
182            "temperature": 0.2,
183            "stream": false,
184            "tools": if tools.is_empty() { None } else { Some(tools) },
185        });
186
187        let mut last_err = String::new();
188        for attempt in 0..3u32 {
189            match self.client.post(&self.api_url).json(&request).send().await {
190                Ok(res) if res.status().is_success() => {
191                    let body: Value = res
192                        .json()
193                        .await
194                        .map_err(|e| format!("LMS parse error: {}", e))?;
195                    let choice = body["choices"].get(0).ok_or("Empty choice from LMS")?;
196                    let message = &choice["message"];
197                    let content = message["content"].as_str().map(|s| s.to_string());
198                    let tool_calls: Option<Vec<ToolCallResponse>> =
199                        serde_json::from_value(message["tool_calls"].clone()).ok();
200                    let usage: TokenUsage =
201                        serde_json::from_value(body["usage"].clone()).unwrap_or_default();
202                    let finish_reason = choice["finish_reason"].as_str().map(|s| s.to_string());
203                    return Ok(ProviderResponse {
204                        content,
205                        tool_calls,
206                        usage,
207                        finish_reason,
208                    });
209                }
210                Ok(res) => {
211                    let status = res.status();
212                    let body = res.text().await.unwrap_or_default();
213                    let body_note = truncate_provider_error_body(&body);
214                    last_err = if body_note.is_empty() {
215                        format!("HTTP {}", status)
216                    } else {
217                        format!("HTTP {} | {}", status, body_note)
218                    };
219                }
220                Err(e) => {
221                    last_err = e.to_string();
222                }
223            }
224            if attempt < 2 {
225                tokio::time::sleep(Duration::from_millis(500)).await;
226            }
227        }
228        Err(format!("LMS unreachable: {}", last_err))
229    }
230
231    async fn stream(
232        &self,
233        messages: &[ChatMessage],
234        tx: mpsc::Sender<InferenceEvent>,
235    ) -> Result<(), Box<dyn std::error::Error>> {
236        let request = serde_json::json!({
237            "model": self.model,
238            "messages": messages,
239            "temperature": 0.2,
240            "stream": true,
241        });
242
243        let res = self
244            .client
245            .post(&self.api_url)
246            .json(&request)
247            .send()
248            .await?;
249        if !res.status().is_success() {
250            return Err(format!("LMS stream error: {}", res.status()).into());
251        }
252
253        use futures::StreamExt;
254        let mut stream = res.bytes_stream();
255        while let Some(chunk) = stream.next().await {
256            let chunk = chunk?;
257            let text = String::from_utf8_lossy(&chunk);
258            for line in text.lines() {
259                if line.starts_with("data: ") {
260                    let data = &line[6..];
261                    if data == "[DONE]" {
262                        break;
263                    }
264                    if let Ok(v) = serde_json::from_str::<Value>(data) {
265                        if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() {
266                            let _ = tx.send(InferenceEvent::Token(delta.to_string())).await;
267                        }
268                    }
269                }
270            }
271        }
272        let _ = tx.send(InferenceEvent::Done).await;
273        Ok(())
274    }
275
276    async fn health_check(&self) -> bool {
277        if self.lms.is_server_responding(&self.base_url).await {
278            return true;
279        }
280        if self.lms.binary_path.is_some() {
281            let _ = self.lms.ensure_server_running();
282            tokio::time::sleep(Duration::from_millis(1500)).await;
283            return self.lms.is_server_responding(&self.base_url).await;
284        }
285        false
286    }
287
288    async fn detect_model(&self) -> Result<String, String> {
289        let base = self.base_url.trim_end_matches('/').trim_end_matches("/v1");
290        let url = format!("{}/api/v0/models", base);
291        if let Ok(res) = self.client.get(&url).send().await {
292            if res.status().is_success() {
293                let body: Value = res.json().await.map_err(|e| e.to_string())?;
294                if let Some(data) = body["data"].as_array() {
295                    for m in data {
296                        let m_type = m["type"].as_str().unwrap_or_default();
297                        if (m_type == "chat" || m_type == "vlm" || m_type == "llm")
298                            && m["state"].as_str() == Some("loaded")
299                        {
300                            return Ok(m["id"].as_str().unwrap_or_default().to_string());
301                        }
302                    }
303                }
304            }
305        }
306        let url_v1 = format!("{}/v1/models", base);
307        let resp_v1 = self
308            .client
309            .get(&url_v1)
310            .send()
311            .await
312            .map_err(|e| e.to_string())?;
313        let body_v1: Value = resp_v1.json().await.map_err(|e| e.to_string())?;
314        if let Some(data) = body_v1["data"].as_array() {
315            if let Some(first) = data.iter().find(|m| {
316                !m["id"]
317                    .as_str()
318                    .unwrap_or_default()
319                    .to_lowercase()
320                    .contains("embed")
321            }) {
322                return Ok(first["id"].as_str().unwrap_or_default().to_string());
323            }
324        }
325        Ok(String::new())
326    }
327
328    async fn detect_context_length(&self) -> usize {
329        let base = self.base_url.trim_end_matches('/').trim_end_matches("/v1");
330        let url = format!("{}/api/v0/models", base);
331        if let Ok(res) = self.client.get(&url).send().await {
332            if res.status().is_success() {
333                let body: Value = res.json().await.unwrap_or_default();
334                if let Some(data) = body["data"].as_array() {
335                    for m in data {
336                        let m_type = m["type"].as_str().unwrap_or_default();
337                        if (m_type == "chat" || m_type == "vlm" || m_type == "llm")
338                            && m["state"].as_str() == Some("loaded")
339                        {
340                            // Try multiple possible field names and nested locations
341                            let fields = [
342                                "loaded_context_length",
343                                "context_length",
344                                "max_context_length",
345                                "contextLength",
346                            ];
347
348                            // Check top-level first
349                            for field in fields {
350                                if let Some(val) = m.get(field) {
351                                    if let Some(len) = val.as_u64() {
352                                        return len as usize;
353                                    }
354                                    if let Some(s) = val.as_str() {
355                                        if let Ok(len) = s.parse::<usize>() {
356                                            return len;
357                                        }
358                                    }
359                                }
360                            }
361
362                            // Check "stats" object
363                            if let Some(stats) = m.get("stats") {
364                                for field in fields {
365                                    if let Some(val) = stats.get(field) {
366                                        if let Some(len) = val.as_u64() {
367                                            return len as usize;
368                                        }
369                                        if let Some(s) = val.as_str() {
370                                            if let Ok(len) = s.parse::<usize>() {
371                                                return len;
372                                            }
373                                        }
374                                    }
375                                }
376                            }
377
378                            // Check "config" object
379                            if let Some(config) = m.get("config") {
380                                for field in fields {
381                                    if let Some(val) = config.get(field) {
382                                        if let Some(len) = val.as_u64() {
383                                            return len as usize;
384                                        }
385                                        if let Some(s) = val.as_str() {
386                                            if let Ok(len) = s.parse::<usize>() {
387                                                return len;
388                                            }
389                                        }
390                                    }
391                                }
392                            }
393                        }
394                    }
395                }
396            }
397        }
398        0
399    }
400
401    async fn load_model(&self, model_id: &str) -> Result<(), String> {
402        self.load_model_with_context(model_id, None).await
403    }
404
405    async fn load_model_with_context(
406        &self,
407        model_id: &str,
408        context_length: Option<usize>,
409    ) -> Result<(), String> {
410        let mut payload = serde_json::json!({ "model": model_id });
411        if let Some(ctx) = context_length {
412            payload["context_length"] = serde_json::json!(ctx);
413        }
414
415        let load_url = format!("{}/api/v1/models/load", self.base_url);
416        if let Ok(res) = self.client.post(&load_url).json(&payload).send().await {
417            if res.status().is_success() {
418                return Ok(());
419            }
420            let body = res.text().await.unwrap_or_default();
421            let body_note = truncate_provider_error_body(&body);
422            if !body_note.is_empty() {
423                return Err(format!("Model load failed: {}", body_note));
424            }
425        }
426
427        if context_length.is_none()
428            && self.lms.binary_path.is_some()
429            && self.lms.load_model(model_id).is_ok()
430        {
431            return Ok(());
432        }
433
434        let payload = serde_json::json!({
435            "model": model_id,
436            "messages": [{"role": "system", "content": "System boot"}],
437            "max_tokens": 1,
438            "stream": false
439        });
440        match self.client.post(&self.api_url).json(&payload).send().await {
441            Ok(res) if res.status().is_success() => Ok(()),
442            Ok(res) => Err(format!("Model load failed: HTTP {}", res.status())),
443            Err(e) => Err(format!("Model load failed: {}", e)),
444        }
445    }
446
447    async fn load_embedding_model(&self, model_id: &str) -> Result<(), String> {
448        self.load_model(model_id).await
449    }
450
451    async fn list_models(
452        &self,
453        kind: ProviderModelKind,
454        loaded_only: bool,
455    ) -> Result<Vec<String>, String> {
456        let mut models = Vec::new();
457
458        if loaded_only {
459            let url = format!("{}/api/v0/models", self.base_url);
460            if let Ok(res) = self.client.get(&url).send().await {
461                if res.status().is_success() {
462                    let body: Value = res.json().await.map_err(|e| e.to_string())?;
463                    if let Some(data) = body["data"].as_array() {
464                        for model in data {
465                            if model["state"].as_str() != Some("loaded") {
466                                continue;
467                            }
468                            let raw_type = model["type"].as_str().unwrap_or_default();
469                            if !matches_lms_model_kind(kind, raw_type) {
470                                continue;
471                            }
472                            if let Some(id) = model["id"].as_str() {
473                                push_unique_model(&mut models, id);
474                            }
475                        }
476                    }
477                }
478            }
479
480            if models.is_empty()
481                && self.lms.binary_path.is_some()
482                && kind != ProviderModelKind::Embed
483            {
484                if let Ok(cli_models) = self.lms.list_loaded_models() {
485                    for model in cli_models {
486                        push_unique_model(&mut models, &model);
487                    }
488                }
489            }
490            return Ok(models);
491        }
492
493        let url = format!("{}/api/v1/models", self.base_url);
494        if let Ok(res) = self.client.get(&url).send().await {
495            if res.status().is_success() {
496                let body: Value = res.json().await.map_err(|e| e.to_string())?;
497                if let Some(data) = body["data"].as_array() {
498                    for model in data {
499                        let raw_type = model["type"].as_str().unwrap_or_default();
500                        if !matches_lms_model_kind(kind, raw_type) {
501                            continue;
502                        }
503                        if let Some(id) = model["id"].as_str() {
504                            push_unique_model(&mut models, id);
505                        }
506                    }
507                }
508            }
509        }
510
511        if models.is_empty() && self.lms.binary_path.is_some() && kind != ProviderModelKind::Embed {
512            if let Ok(cli_models) = self.lms.list_models() {
513                for model in cli_models {
514                    push_unique_model(&mut models, &model);
515                }
516            }
517        }
518
519        Ok(models)
520    }
521
522    async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String> {
523        if all {
524            let loaded = self.list_models(ProviderModelKind::Any, true).await?;
525            if loaded.is_empty() {
526                return Ok("No LM Studio models are currently loaded.".to_string());
527            }
528
529            if self.lms.binary_path.is_some() && self.lms.unload_all_models().is_ok() {
530                return Ok(format!("Unloaded {} LM Studio model(s).", loaded.len()));
531            }
532
533            let unload_url = format!("{}/api/v1/models/unload", self.base_url);
534            let mut unloaded = 0usize;
535            let mut failures = Vec::new();
536            for instance_id in loaded {
537                match self
538                    .client
539                    .post(&unload_url)
540                    .json(&serde_json::json!({ "instance_id": instance_id }))
541                    .send()
542                    .await
543                {
544                    Ok(res) if res.status().is_success() => unloaded += 1,
545                    Ok(res) => failures.push(format!("{} ({})", instance_id, res.status())),
546                    Err(e) => failures.push(format!("{} ({})", instance_id, e)),
547                }
548            }
549            if failures.is_empty() {
550                return Ok(format!("Unloaded {} LM Studio model(s).", unloaded));
551            }
552            return Err(format!(
553                "Unloaded {} LM Studio model(s), but some unloads failed: {}",
554                unloaded,
555                failures.join(", ")
556            ));
557        }
558
559        let target = model_id
560            .map(str::trim)
561            .filter(|value| !value.is_empty())
562            .ok_or_else(|| "Missing model ID to unload.".to_string())?;
563
564        let unload_url = format!("{}/api/v1/models/unload", self.base_url);
565        match self
566            .client
567            .post(&unload_url)
568            .json(&serde_json::json!({ "instance_id": target }))
569            .send()
570            .await
571        {
572            Ok(res) if res.status().is_success() => {
573                Ok(format!("Unloaded LM Studio model `{}`.", target))
574            }
575            Ok(res) => {
576                let status = res.status();
577                let body = res.text().await.unwrap_or_default();
578                let body_note = truncate_provider_error_body(&body);
579                if self.lms.binary_path.is_some() && self.lms.unload_model(target).is_ok() {
580                    Ok(format!("Unloaded LM Studio model `{}`.", target))
581                } else if body_note.is_empty() {
582                    Err(format!("LM Studio unload failed: HTTP {}", status))
583                } else {
584                    Err(format!(
585                        "LM Studio unload failed: HTTP {} | {}",
586                        status, body_note
587                    ))
588                }
589            }
590            Err(err) => {
591                if self.lms.binary_path.is_some() && self.lms.unload_model(target).is_ok() {
592                    Ok(format!("Unloaded LM Studio model `{}`.", target))
593                } else {
594                    Err(format!("LM Studio unload failed: {}", err))
595                }
596            }
597        }
598    }
599
600    async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String> {
601        self.unload_model(model_id, false).await
602    }
603
604    async fn prewarm(&self) -> Result<(), String> {
605        let payload = serde_json::json!({
606            "model": self.model,
607            "messages": [{"role": "system", "content": "Hematite BootSequence"}],
608            "max_tokens": 1,
609            "stream": false
610        });
611        let _ = self.client.post(&self.api_url).json(&payload).send().await;
612        Ok(())
613    }
614
615    async fn get_embedding_model(&self) -> Option<String> {
616        let url = format!("{}/api/v0/models", self.base_url);
617        if let Ok(res) = self.client.get(&url).send().await {
618            if let Ok(body) = res.json::<Value>().await {
619                if let Some(data) = body["data"].as_array() {
620                    return data
621                        .iter()
622                        .find(|m| {
623                            m["type"].as_str() == Some("embeddings")
624                                && m["state"].as_str() == Some("loaded")
625                        })
626                        .map(|m| m["id"].as_str().unwrap_or_default().to_string());
627                }
628            }
629        }
630        None
631    }
632
633    fn name(&self) -> &str {
634        "LM Studio"
635    }
636    fn current_model(&self) -> String {
637        self.model.clone()
638    }
639    fn context_length(&self) -> usize {
640        self.context_length
641    }
642    fn set_runtime_profile(&mut self, model: &str, context_length: usize) {
643        self.model = model.to_string();
644        self.context_length = context_length;
645    }
646}
647
648pub struct OllamaProvider {
649    pub client: Client,
650    pub base_url: String,
651    pub model: String,
652    pub context_length: usize,
653    pub embed_model: std::sync::Arc<std::sync::RwLock<Option<String>>>,
654    pub ollama: crate::agent::ollama::OllamaHarness,
655}
656
657#[async_trait]
658impl ModelProvider for OllamaProvider {
659    async fn call_with_tools(
660        &self,
661        messages: &[ChatMessage],
662        tools: &[ToolDefinition],
663        model_override: Option<&str>,
664    ) -> Result<ProviderResponse, String> {
665        let model = model_override.unwrap_or(&self.model).to_string();
666        let url = format!("{}/api/chat", self.base_url);
667        let request = serde_json::json!({
668            "model": model, "messages": messages, "stream": false,
669            "tools": if tools.is_empty() { None } else { Some(tools) },
670        });
671        let res = self
672            .client
673            .post(&url)
674            .json(&request)
675            .send()
676            .await
677            .map_err(|e| e.to_string())?;
678        if !res.status().is_success() {
679            return Err(format!("Ollama error: {}", res.status()));
680        }
681        let body: Value = res.json().await.map_err(|e| e.to_string())?;
682        let message = &body["message"];
683        let content = message["content"].as_str().map(|s| s.to_string());
684        let tool_calls = if let Some(calls) = message["tool_calls"].as_array() {
685            let mut mapped = Vec::new();
686            for (i, c) in calls.iter().enumerate() {
687                mapped.push(ToolCallResponse {
688                    id: format!("call_{}", i),
689                    call_type: "function".to_string(),
690                    function: ToolCallFn {
691                        name: c["function"]["name"]
692                            .as_str()
693                            .unwrap_or_default()
694                            .to_string(),
695                        arguments: c["function"]["arguments"].clone(),
696                    },
697                    index: Some(i as i32),
698                });
699            }
700            Some(mapped)
701        } else {
702            None
703        };
704        let usage = TokenUsage {
705            prompt_tokens: body["prompt_eval_count"].as_u64().unwrap_or(0) as usize,
706            completion_tokens: body["eval_count"].as_u64().unwrap_or(0) as usize,
707            ..Default::default()
708        };
709        Ok(ProviderResponse {
710            content,
711            tool_calls,
712            usage,
713            finish_reason: Some("stop".to_string()),
714        })
715    }
716
717    async fn stream(
718        &self,
719        messages: &[ChatMessage],
720        tx: mpsc::Sender<InferenceEvent>,
721    ) -> Result<(), Box<dyn std::error::Error>> {
722        let url = format!("{}/api/chat", self.base_url);
723        let request =
724            serde_json::json!({ "model": self.model, "messages": messages, "stream": true });
725        let res = self.client.post(&url).json(&request).send().await?;
726        use futures::StreamExt;
727        let mut stream = res.bytes_stream();
728        while let Some(chunk) = stream.next().await {
729            let chunk = chunk?;
730            if let Ok(v) = serde_json::from_slice::<Value>(&chunk) {
731                if let Some(delta) = v["message"]["content"].as_str() {
732                    let _ = tx.send(InferenceEvent::Token(delta.to_string())).await;
733                }
734                if v["done"].as_bool().unwrap_or(false) {
735                    break;
736                }
737            }
738        }
739        let _ = tx.send(InferenceEvent::Done).await;
740        Ok(())
741    }
742
743    async fn health_check(&self) -> bool {
744        self.ollama.is_reachable().await
745    }
746    async fn detect_model(&self) -> Result<String, String> {
747        let running_url = format!("{}/api/ps", self.base_url);
748        if let Ok(resp) = self.client.get(&running_url).send().await {
749            let body: Value = resp.json().await.map_err(|e| e.to_string())?;
750            if let Some(models) = body["models"].as_array() {
751                if let Some(first) = models.first() {
752                    let name = first["name"]
753                        .as_str()
754                        .or_else(|| first["model"].as_str())
755                        .unwrap_or_default();
756                    return Ok(name.to_string());
757                }
758                return Ok(String::new());
759            }
760        }
761
762        if !self.model.trim().is_empty() {
763            return Ok(self.model.clone());
764        }
765
766        let url = format!("{}/api/tags", self.base_url);
767        let resp = self
768            .client
769            .get(&url)
770            .send()
771            .await
772            .map_err(|e| e.to_string())?;
773        let body: Value = resp.json().await.map_err(|e| e.to_string())?;
774        if let Some(models) = body["models"].as_array() {
775            if let Some(first) = models.first() {
776                return Ok(first["name"].as_str().unwrap_or_default().to_string());
777            }
778        }
779        Ok(String::new())
780    }
781    async fn detect_context_length(&self) -> usize {
782        let running_url = format!("{}/api/ps", self.base_url);
783        if let Ok(resp) = self.client.get(&running_url).send().await {
784            if let Ok(body) = resp.json::<Value>().await {
785                if let Some(models) = body["models"].as_array() {
786                    if let Some(first) = models.first() {
787                        if let Some(context_length) = first["context_length"].as_u64() {
788                            return context_length as usize;
789                        }
790                    }
791                }
792            }
793        }
794        self.context_length
795    }
796    async fn load_model(&self, _model_id: &str) -> Result<(), String> {
797        self.load_model_with_context(_model_id, None).await
798    }
799    async fn load_model_with_context(
800        &self,
801        model_id: &str,
802        context_length: Option<usize>,
803    ) -> Result<(), String> {
804        if !self.ollama.has_model(model_id).await? {
805            return Err(format!(
806                "Ollama model `{}` is not pulled locally. Run `ollama pull {}` first.",
807                model_id, model_id
808            ));
809        }
810        let url = format!("{}/api/generate", self.base_url);
811        let request = serde_json::json!({
812            "model": model_id,
813            "prompt": "Hematite runtime warmup",
814            "stream": false,
815            "keep_alive": "30m",
816            "options": {
817                "num_ctx": context_length.unwrap_or(self.context_length.max(4096))
818            }
819        });
820        let res = self
821            .client
822            .post(&url)
823            .json(&request)
824            .send()
825            .await
826            .map_err(|e| e.to_string())?;
827        let status = res.status();
828        if status.is_success() {
829            Ok(())
830        } else {
831            let body = res.text().await.unwrap_or_default();
832            let body_note = truncate_provider_error_body(&body);
833            if body_note.is_empty() {
834                Err(format!("Ollama load failed: HTTP {}", status))
835            } else {
836                Err(format!(
837                    "Ollama load failed: HTTP {} | {}",
838                    status, body_note
839                ))
840            }
841        }
842    }
843    async fn load_embedding_model(&self, model_id: &str) -> Result<(), String> {
844        if !self.ollama.has_model(model_id).await? {
845            return Err(format!(
846                "Ollama embedding model `{}` is not pulled locally. Run `ollama pull {}` first.",
847                model_id, model_id
848            ));
849        }
850        let url = format!("{}/api/embed", self.base_url);
851        let request = serde_json::json!({
852            "model": model_id,
853            "input": "search_document: Hematite semantic search warmup",
854            "keep_alive": "30m"
855        });
856        let res = self
857            .client
858            .post(&url)
859            .json(&request)
860            .send()
861            .await
862            .map_err(|e| e.to_string())?;
863        let status = res.status();
864        if !status.is_success() {
865            let body = res.text().await.unwrap_or_default();
866            let body_note = truncate_provider_error_body(&body);
867            return if body_note.is_empty() {
868                Err(format!("Ollama embed load failed: HTTP {}", status))
869            } else {
870                Err(format!(
871                    "Ollama embed load failed: HTTP {} | {}",
872                    status, body_note
873                ))
874            };
875        }
876        if let Ok(mut guard) = self.embed_model.write() {
877            *guard = Some(model_id.to_string());
878        }
879        Ok(())
880    }
881    async fn list_models(
882        &self,
883        kind: ProviderModelKind,
884        loaded_only: bool,
885    ) -> Result<Vec<String>, String> {
886        let url = if loaded_only {
887            format!("{}/api/ps", self.base_url)
888        } else {
889            format!("{}/api/tags", self.base_url)
890        };
891        let resp = self
892            .client
893            .get(&url)
894            .send()
895            .await
896            .map_err(|e| e.to_string())?;
897        let body: Value = resp.json().await.map_err(|e| e.to_string())?;
898        let mut models = Vec::new();
899        if let Some(entries) = body["models"].as_array() {
900            for entry in entries {
901                let name = entry["name"]
902                    .as_str()
903                    .or_else(|| entry["model"].as_str())
904                    .unwrap_or_default();
905                if kind == ProviderModelKind::Embed && !looks_like_embedding_model_name(name) {
906                    continue;
907                }
908                if kind == ProviderModelKind::Coding && looks_like_embedding_model_name(name) {
909                    continue;
910                }
911                push_unique_model(&mut models, name);
912            }
913        }
914        if loaded_only && kind == ProviderModelKind::Embed {
915            if let Ok(guard) = self.embed_model.read() {
916                if let Some(model) = guard.as_deref() {
917                    push_unique_model(&mut models, model);
918                }
919            }
920        }
921        Ok(models)
922    }
923    async fn unload_model(&self, model_id: Option<&str>, all: bool) -> Result<String, String> {
924        let targets = if all {
925            self.list_models(ProviderModelKind::Coding, true).await?
926        } else {
927            vec![model_id
928                .map(str::trim)
929                .filter(|value| !value.is_empty())
930                .ok_or_else(|| "Missing model ID to unload.".to_string())?
931                .to_string()]
932        };
933
934        if targets.is_empty() {
935            return Ok("No Ollama models are currently loaded.".to_string());
936        }
937
938        let url = format!("{}/api/generate", self.base_url);
939        let mut unloaded = 0usize;
940        let mut failures = Vec::new();
941        for target in targets {
942            let request = serde_json::json!({
943                "model": target,
944                "prompt": "",
945                "stream": false,
946                "keep_alive": 0
947            });
948            match self.client.post(&url).json(&request).send().await {
949                Ok(res) if res.status().is_success() => unloaded += 1,
950                Ok(res) => failures.push(format!("{} ({})", target, res.status())),
951                Err(e) => failures.push(format!("{} ({})", target, e)),
952            }
953        }
954
955        if failures.is_empty() {
956            return Ok(if all {
957                format!("Unloaded {} Ollama model(s).", unloaded)
958            } else {
959                format!("Unloaded Ollama model `{}`.", model_id.unwrap_or_default())
960            });
961        }
962
963        Err(format!(
964            "Unloaded {} Ollama model(s), but some unloads failed: {}",
965            unloaded,
966            failures.join(", ")
967        ))
968    }
969    async fn unload_embedding_model(&self, model_id: Option<&str>) -> Result<String, String> {
970        let target = match model_id {
971            Some(explicit) if !explicit.trim().is_empty() => explicit.trim().to_string(),
972            _ => self
973                .get_embedding_model()
974                .await
975                .ok_or_else(|| "No Ollama embedding model is currently loaded.".to_string())?,
976        };
977        let url = format!("{}/api/embed", self.base_url);
978        let request = serde_json::json!({
979            "model": target,
980            "input": "search_document: Hematite semantic search warmup",
981            "keep_alive": 0
982        });
983        let res = self
984            .client
985            .post(&url)
986            .json(&request)
987            .send()
988            .await
989            .map_err(|e| e.to_string())?;
990        if res.status().is_success() {
991            if let Ok(mut guard) = self.embed_model.write() {
992                if guard.as_deref() == Some(target.as_str()) {
993                    *guard = None;
994                }
995            }
996            Ok(format!("Unloaded Ollama embedding model `{}`.", target))
997        } else {
998            let status = res.status();
999            let body = res.text().await.unwrap_or_default();
1000            let body_note = truncate_provider_error_body(&body);
1001            if body_note.is_empty() {
1002                Err(format!("Ollama embed unload failed: HTTP {}", status))
1003            } else {
1004                Err(format!(
1005                    "Ollama embed unload failed: HTTP {} | {}",
1006                    status, body_note
1007                ))
1008            }
1009        }
1010    }
1011    async fn prewarm(&self) -> Result<(), String> {
1012        Ok(())
1013    }
1014    async fn get_embedding_model(&self) -> Option<String> {
1015        if let Ok(guard) = self.embed_model.read() {
1016            if let Some(model) = guard.as_ref() {
1017                return Some(model.clone());
1018            }
1019        }
1020
1021        let url = format!("{}/api/ps", self.base_url);
1022        if let Ok(res) = self.client.get(&url).send().await {
1023            if let Ok(body) = res.json::<Value>().await {
1024                if let Some(entries) = body["models"].as_array() {
1025                    for entry in entries {
1026                        let name = entry["name"]
1027                            .as_str()
1028                            .or_else(|| entry["model"].as_str())
1029                            .unwrap_or_default();
1030                        if looks_like_embedding_model_name(name) {
1031                            return Some(name.to_string());
1032                        }
1033                    }
1034                }
1035            }
1036        }
1037        None
1038    }
1039
1040    fn name(&self) -> &str {
1041        "Ollama"
1042    }
1043    fn current_model(&self) -> String {
1044        self.model.clone()
1045    }
1046    fn context_length(&self) -> usize {
1047        self.context_length
1048    }
1049    fn set_runtime_profile(&mut self, model: &str, context_length: usize) {
1050        self.model = model.to_string();
1051        self.context_length = context_length;
1052    }
1053}
1054
#[cfg(test)]
mod tests {
    use super::{
        lms_messages_payload, looks_like_embedding_model_name, matches_lms_model_kind,
        ProviderModelKind,
    };
    use crate::agent::types::{ChatMessage, ToolCallFn, ToolCallResponse};
    use serde_json::json;

    /// Tool-call arguments given as a JSON object must appear in the payload
    /// as a JSON-encoded string.
    #[test]
    fn lms_payload_stringifies_assistant_tool_arguments() {
        let read_file_call = ToolCallResponse {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: ToolCallFn {
                name: "read_file".to_string(),
                arguments: json!({"path":"index.html"}),
            },
            index: None,
        };
        let conversation = vec![ChatMessage::assistant_tool_calls(
            "",
            vec![read_file_call],
        )];

        let payload = lms_messages_payload(&conversation);
        let encoded = &payload[0]["tool_calls"][0]["function"]["arguments"];
        assert!(encoded.is_string());
        assert_eq!(
            encoded.as_str().unwrap_or_default(),
            "{\"path\":\"index.html\"}"
        );
    }

    /// Each model kind must match only its own type label.
    #[test]
    fn lms_model_kind_matching_distinguishes_embedding_models() {
        let accepted = [
            (ProviderModelKind::Coding, "chat"),
            (ProviderModelKind::Embed, "embeddings"),
        ];
        for &(kind, label) in accepted.iter() {
            assert!(matches_lms_model_kind(kind, label));
        }

        let rejected = [
            (ProviderModelKind::Coding, "embeddings"),
            (ProviderModelKind::Embed, "chat"),
        ];
        for &(kind, label) in rejected.iter() {
            assert!(!matches_lms_model_kind(kind, label));
        }
    }

    /// The name heuristic must accept the listed embedding model names and
    /// reject a chat model name.
    #[test]
    fn embedding_name_heuristic_catches_common_ollama_embed_models() {
        let embedding_names = ["embeddinggemma", "qwen3-embedding", "all-minilm"];
        for &candidate in embedding_names.iter() {
            assert!(looks_like_embedding_model_name(candidate));
        }
        assert!(!looks_like_embedding_model_name("qwen3.5:latest"));
    }
}