llmkit_ollama/
provider.rs1use std::time::Instant;
4
5use async_trait::async_trait;
6use llmkit_core::{
7 ChatRequest, ChatResponse, ChatStream, EmbedRequest, EmbedResponse, LlmError, LlmProvider,
8 LlmResult, TokenUsage,
9};
10
11use crate::types::{ChatResponseBody, EmbeddingsRequestBody, EmbeddingsResponseBody};
12use crate::{chat, stream};
13
/// Model name a freshly constructed provider starts with.
const DEFAULT_MODEL: &str = "llama3.1";

/// Server root a freshly constructed provider starts with (a local Ollama instance).
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
16
17#[derive(Clone)]
19pub struct OllamaProvider {
20 http: reqwest::Client,
21 base_url: String,
22 model: String,
23}
24
25impl OllamaProvider {
26 pub fn new() -> Self {
28 Self {
29 http: reqwest::Client::new(),
30 base_url: DEFAULT_BASE_URL.to_string(),
31 model: DEFAULT_MODEL.to_string(),
32 }
33 }
34
35 pub fn from_env() -> Self {
37 let mut p = Self::new();
38 if let Ok(host) = std::env::var("OLLAMA_HOST") {
39 p.base_url = host;
40 }
41 p
42 }
43
44 pub fn model(mut self, model: impl Into<String>) -> Self {
46 self.model = model.into();
47 self
48 }
49
50 pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
52 self.base_url = base_url.into();
53 self
54 }
55
56 pub fn with_client(mut self, client: reqwest::Client) -> Self {
58 self.http = client;
59 self
60 }
61
62 fn resolved_model(&self, req: &ChatRequest) -> String {
63 req.model.clone().unwrap_or_else(|| self.model.clone())
64 }
65}
66
67impl Default for OllamaProvider {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73#[async_trait]
74impl LlmProvider for OllamaProvider {
75 async fn chat(&self, req: ChatRequest) -> LlmResult<ChatResponse> {
76 let model = self.resolved_model(&req);
77 let body = chat::build_request(&req, model, false);
78
79 let start = Instant::now();
80 let resp = self
81 .http
82 .post(format!("{}/api/chat", self.base_url))
83 .json(&body)
84 .send()
85 .await
86 .map_err(map_reqwest_err)?;
87
88 let resp = check_status(resp).await?;
89 let parsed: ChatResponseBody = resp.json().await.map_err(map_reqwest_err)?;
90 chat::map_response(parsed, start.elapsed().as_millis() as u64)
91 }
92
93 async fn chat_stream(&self, req: ChatRequest) -> LlmResult<ChatStream> {
94 let model = self.resolved_model(&req);
95 let body = chat::build_request(&req, model, true);
96
97 let resp = self
98 .http
99 .post(format!("{}/api/chat", self.base_url))
100 .json(&body)
101 .send()
102 .await
103 .map_err(map_reqwest_err)?;
104
105 let resp = check_status(resp).await?;
106 Ok(stream::parse(resp))
107 }
108
109 async fn embed(&self, req: EmbedRequest) -> LlmResult<EmbedResponse> {
110 let model = req.model.clone().unwrap_or_else(|| self.model.clone());
111 let body = EmbeddingsRequestBody { model, input: req.input };
112
113 let resp = self
114 .http
115 .post(format!("{}/api/embed", self.base_url))
116 .json(&body)
117 .send()
118 .await
119 .map_err(map_reqwest_err)?;
120
121 let resp = check_status(resp).await?;
122 let parsed: EmbeddingsResponseBody = resp.json().await.map_err(map_reqwest_err)?;
123
124 Ok(EmbedResponse {
125 provider: "ollama".into(),
126 model: parsed.model,
127 embeddings: parsed.embeddings,
128 usage: TokenUsage::new(parsed.prompt_eval_count.unwrap_or(0), 0),
129 })
130 }
131
132 fn name(&self) -> &'static str {
133 "ollama"
134 }
135
136 fn model(&self) -> &str {
137 &self.model
138 }
139}
140
141fn map_reqwest_err(e: reqwest::Error) -> LlmError {
142 if e.is_timeout() {
143 LlmError::Timeout
144 } else if e.is_connect() {
145 LlmError::Transport(format!("cannot reach Ollama server: {e}"))
146 } else if e.is_decode() {
147 LlmError::Serialization(e.to_string())
148 } else {
149 LlmError::Transport(e.to_string())
150 }
151}
152
153async fn check_status(resp: reqwest::Response) -> LlmResult<reqwest::Response> {
154 let status = resp.status();
155 if status.is_success() {
156 return Ok(resp);
157 }
158 let code = status.as_u16();
159 let message = resp.text().await.unwrap_or_default();
160 Err(match code {
161 404 => LlmError::InvalidRequest(format!("model not found or endpoint missing: {message}")),
162 400 => LlmError::InvalidRequest(message),
163 _ => LlmError::Provider { status: code, message },
164 })
165}