
devsper_providers/ollama.rs

use devsper_core::{LlmProvider, LlmRequest, LlmResponse, LlmRole, StopReason};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::debug;

/// Ollama local model provider.
pub struct OllamaProvider {
    client: Client,
    base_url: String,
}

impl OllamaProvider {
    /// Creates a provider pointing at the default local endpoint.
    pub fn new() -> Self {
        Self {
            client: Client::new(),
            base_url: "http://localhost:11434".to_string(),
        }
    }

    /// Overrides the base URL, e.g. for a non-default port or a remote host.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into();
        self
    }
}

impl Default for OllamaProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Serialize)]
struct OllamaRequest<'a> {
    model: &'a str,
    prompt: String,
    stream: bool,
}

#[derive(Deserialize)]
struct OllamaResponse {
    response: String,
    #[serde(default)]
    prompt_eval_count: u32,
    #[serde(default)]
    eval_count: u32,
}
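
// For reference: these structs mirror the non-streaming `/api/generate`
// exchange, which looks roughly like
//   request:  {"model":"llama3","prompt":"User: hi\n","stream":false}
//   response: {"response":"Hello!","prompt_eval_count":26,"eval_count":7, ...}
// (model name and counts above are illustrative). Response fields not declared
// on `OllamaResponse` are simply ignored during deserialization.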

#[async_trait]
impl LlmProvider for OllamaProvider {
    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
        use tracing::Instrument;

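        // The `gen_ai.*` field names follow the OpenTelemetry GenAI semantic
        // conventions; the response/usage fields start out Empty and are filled
        // in after the request completes (see the `span.record` calls below).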
        let span = tracing::info_span!(
            "gen_ai.chat",
            "gen_ai.system" = self.name(),
            "gen_ai.operation.name" = "chat",
            "gen_ai.request.model" = req.model.as_str(),
            "gen_ai.request.max_tokens" = req.max_tokens,
            "gen_ai.response.model" = tracing::field::Empty,
            "gen_ai.usage.input_tokens" = tracing::field::Empty,
            "gen_ai.usage.output_tokens" = tracing::field::Empty,
        );

        // Flatten the chat history into one prompt string: `/api/generate` takes a
        // plain completion-style prompt rather than a structured message list.
        let prompt = req
            .messages
            .iter()
            .map(|m| match m.role {
                LlmRole::System => format!("System: {}\n", m.content),
                LlmRole::User | LlmRole::Tool => format!("User: {}\n", m.content),
                LlmRole::Assistant => format!("Assistant: {}\n", m.content),
            })
            .collect::<String>();

        // Strip the "ollama:" prefix if present so Ollama sees the bare model name.
        let model = req.model.strip_prefix("ollama:").unwrap_or(&req.model);

        debug!(model = %model, "Ollama request");

        let body = OllamaRequest {
            model,
            prompt,
            // A single JSON object is requested instead of a streamed response.
            stream: false,
        };

        let model_name = req.model.clone();
        let result = async {
            let resp = self
                .client
                .post(format!("{}/api/generate", self.base_url))
                .json(&body)
                .send()
                .await?;

            if !resp.status().is_success() {
                let status = resp.status();
                let text = resp.text().await.unwrap_or_default();
                return Err(anyhow!("Ollama error {status}: {text}"));
            }

            let data: OllamaResponse = resp.json().await?;

            Ok(LlmResponse {
                content: data.response,
                // No tool calls are parsed from this completion-style endpoint.
                tool_calls: vec![],
                input_tokens: data.prompt_eval_count,
                output_tokens: data.eval_count,
                model: model_name,
                stop_reason: StopReason::EndTurn,
            })
        }
        .instrument(span.clone())
        .await;

        // Record the response attributes on the span only when the call succeeded.
        if let Ok(ref resp) = result {
            span.record("gen_ai.response.model", resp.model.as_str());
            span.record("gen_ai.usage.input_tokens", resp.input_tokens);
            span.record("gen_ai.usage.output_tokens", resp.output_tokens);
        }
        result
    }

    fn name(&self) -> &str {
        "ollama"
    }

    fn supports_model(&self, model: &str) -> bool {
        model.starts_with("ollama:")
    }
}
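
// A minimal offline sketch of how the builder and model routing behave; it
// performs no network I/O (only `generate` talks to the server), and the URL
// below is illustrative.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_and_model_routing() {
        let provider = OllamaProvider::new().with_base_url("http://127.0.0.1:11434");
        assert_eq!(provider.name(), "ollama");
        assert!(provider.supports_model("ollama:llama3"));
        assert!(!provider.supports_model("gpt-4o"));
    }
}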