1use base64::engine::general_purpose::STANDARD as B64;
2use base64::Engine;
3use ollama_rs::Ollama;
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use url::Url;
6
7use crate::config::{AuthConfig, RuntimeConfig, RuntimeMode};
8use crate::error::{Result, RuntimeError};
9use crate::guard::ExecutionGuard;
10#[cfg(feature = "stream")]
11use crate::guard::GuardedStream;
12
/// Handle bundling an Ollama client with the execution guard and the
/// runtime settings copied from the configuration it was built with.
pub struct OllamaRuntime {
    /// Ollama API client built from the normalized base URL and the
    /// configured HTTP client.
    client: Ollama,
    /// Guard through which `call` (and `call_stream`) run their requests.
    guard: ExecutionGuard,
    /// Copied from the configuration; forwarded to the model check in `ensure`.
    auto_pull: bool,
    /// Copied from the configuration; forwarded to the model check in `ensure`.
    mode: RuntimeMode,
}
20
21impl OllamaRuntime {
22 pub async fn new(config: RuntimeConfig) -> Result<Self> {
23 config.validate()?;
24 let base = parse_and_normalize_base_url(&config.base_url)?;
25 let port = base.port().ok_or_else(|| {
26 RuntimeError::Other("internal error: base_url normalization did not set a port".into())
27 })?;
28
29 let reqwest_client = build_reqwest_client(&config)?;
30
31 let client = Ollama::new_with_client(base.as_str(), port, reqwest_client);
32
33 let guard = ExecutionGuard::new(config.max_concurrent, config.timeout, config.max_retries);
34
35 Ok(Self {
36 client,
37 guard,
38 auto_pull: config.auto_pull,
39 mode: config.mode,
40 })
41 }
42
43 pub fn client(&self) -> &Ollama {
44 &self.client
45 }
46
47 pub fn guard(&self) -> &ExecutionGuard {
48 &self.guard
49 }
50
51 pub fn auto_pull(&self) -> bool {
52 self.auto_pull
53 }
54
55 pub fn mode(&self) -> RuntimeMode {
56 self.mode
57 }
58
59 pub async fn ensure(&self, model: &str) -> Result<()> {
60 crate::model::ensure_local_model(self.auto_pull, self.mode, self.client(), model).await
61 }
62
63 pub async fn call<F, Fut, T>(&self, f: F) -> Result<T>
64 where
65 F: Fn(&Ollama) -> Fut,
66 Fut: std::future::Future<Output = ollama_rs::error::Result<T>>,
67 {
68 self.guard().run(|| f(self.client())).await
69 }
70
71 #[cfg(feature = "stream")]
72 pub async fn call_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
73 where
74 F: Fn(&Ollama) -> Fut,
75 Fut: std::future::Future<Output = ollama_rs::error::Result<S>>,
76 S: tokio_stream::Stream + Unpin,
77 {
78 self.guard().run_stream(|| f(self.client())).await
79 }
80}
81
82fn parse_and_normalize_base_url(raw: &str) -> Result<Url> {
83 let mut url =
84 Url::parse(raw).map_err(|e| RuntimeError::Other(format!("invalid base_url: {e}")))?;
85
86 if url.host().is_none() {
87 return Err(RuntimeError::Other("base_url must include a host".into()));
88 }
89
90 if url.port().is_none() {
91 let port = match url.scheme() {
92 "http" => 11434u16,
93 "https" => 443u16,
94 other => {
95 return Err(RuntimeError::Other(format!(
96 "unsupported URL scheme: {other} (expected http or https)"
97 )));
98 }
99 };
100 url.set_port(Some(port)).map_err(|_| {
101 RuntimeError::Other("invalid base_url: could not apply default port".into())
102 })?;
103 }
104
105 let path = url.path();
106 if path.is_empty() || path == "/" {
107 url.set_path("/");
108 } else if !path.ends_with('/') {
109 url.set_path(&format!("{}/", path.trim_end_matches('/')));
110 }
111
112 Ok(url)
113}
114
115fn build_reqwest_client(config: &RuntimeConfig) -> Result<reqwest::Client> {
116 let mut headers = HeaderMap::new();
117 if let Some(auth) = &config.auth {
118 merge_auth_headers(&mut headers, auth)?;
119 }
120
121 reqwest::Client::builder()
122 .default_headers(headers)
123 .connect_timeout(config.connect_timeout)
124 .timeout(config.timeout)
125 .build()
126 .map_err(|e| RuntimeError::Other(format!("failed to build HTTP client: {e}")))
127}
128
129fn merge_auth_headers(headers: &mut HeaderMap, auth: &AuthConfig) -> Result<()> {
130 match auth {
131 AuthConfig::BearerToken(token) => {
132 let value = HeaderValue::try_from(format!("Bearer {token}")).map_err(header_err)?;
133 headers.insert(HeaderName::from_static("authorization"), value);
134 }
135 AuthConfig::Basic { username, password } => {
136 let creds = format!("{}:{}", username, password);
137 let b64 = B64.encode(creds.as_bytes());
138 let value = HeaderValue::try_from(format!("Basic {b64}")).map_err(header_err)?;
139 headers.insert(HeaderName::from_static("authorization"), value);
140 }
141 AuthConfig::CustomHeader { key, value } => {
142 let name = HeaderName::try_from(key.as_str()).map_err(header_err)?;
143 let val = HeaderValue::try_from(value.as_str()).map_err(header_err)?;
144 headers.insert(name, val);
145 }
146 }
147 Ok(())
148}
149
150fn header_err<E: std::fmt::Display>(e: E) -> RuntimeError {
151 RuntimeError::Other(format!("invalid auth header: {e}"))
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare http host gains the default port 11434 and a root path.
    #[test]
    fn normalizes_empty_path_and_port() {
        let normalized = parse_and_normalize_base_url("http://127.0.0.1")
            .expect("bare http host should normalize");
        assert_eq!("http://127.0.0.1:11434/", normalized.as_str());
    }

    /// A sub-path without a trailing slash gets one appended.
    #[test]
    fn normalizes_subpath_trailing_slash() {
        let normalized = parse_and_normalize_base_url("http://proxy/v1/ollama")
            .expect("sub-path URL should normalize");
        let rendered = normalized.as_str();
        assert!(rendered.ends_with("/v1/ollama/"));
    }
}