1use base64::engine::general_purpose::STANDARD as B64;
2use base64::Engine;
3use ollama_rs::Ollama;
4use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
5use url::Url;
6
7use crate::config::{AuthConfig, RuntimeConfig};
8use crate::error::{Result, RuntimeError};
9use crate::guard::ExecutionGuard;
10#[cfg(feature = "stream")]
11use crate::guard::GuardedStream;
12use crate::model::ModelManager;
13
14pub struct OllamaRuntime {
16 client: Ollama,
17 manager: ModelManager,
18 guard: ExecutionGuard,
19}
20
21impl OllamaRuntime {
22 pub async fn new(config: RuntimeConfig) -> Result<Self> {
25 config.validate()?;
26 let base = parse_and_normalize_base_url(&config.base_url)?;
27 let port = base.port().ok_or_else(|| {
28 RuntimeError::Other("internal error: base_url normalization did not set a port".into())
29 })?;
30
31 let reqwest_client = build_reqwest_client(&config)?;
32
33 let client = Ollama::new_with_client(base.as_str(), port, reqwest_client);
34
35 let guard = ExecutionGuard::new(config.max_concurrent, config.timeout, config.max_retries);
36
37 let manager = ModelManager::new(client.clone(), config.auto_pull, config.mode);
38
39 Ok(Self {
40 client,
41 manager,
42 guard,
43 })
44 }
45
46 pub fn client(&self) -> &Ollama {
47 &self.client
48 }
49
50 pub fn guard(&self) -> &ExecutionGuard {
51 &self.guard
52 }
53
54 pub fn models(&self) -> &ModelManager {
55 &self.manager
56 }
57
58 pub async fn ensure_model(&self, model: &str) -> Result<()> {
60 self.models().ensure(model).await
61 }
62
63 pub async fn run<F, Fut, T>(&self, f: F) -> Result<T>
69 where
70 F: Fn() -> Fut,
71 Fut: std::future::Future<Output = ollama_rs::error::Result<T>>,
72 {
73 self.guard().run(f).await
74 }
75
76 #[cfg(feature = "stream")]
78 pub async fn run_stream<F, Fut, S>(&self, f: F) -> Result<GuardedStream<S>>
79 where
80 F: Fn() -> Fut,
81 Fut: std::future::Future<Output = ollama_rs::error::Result<S>>,
82 S: tokio_stream::Stream + Unpin,
83 {
84 self.guard().run_stream(f).await
85 }
86}
87
88fn parse_and_normalize_base_url(raw: &str) -> Result<Url> {
89 let mut url =
90 Url::parse(raw).map_err(|e| RuntimeError::Other(format!("invalid base_url: {e}")))?;
91
92 if url.host().is_none() {
93 return Err(RuntimeError::Other("base_url must include a host".into()));
94 }
95
96 if url.port().is_none() {
97 let port = match url.scheme() {
98 "http" => 11434u16,
99 "https" => 443u16,
100 other => {
101 return Err(RuntimeError::Other(format!(
102 "unsupported URL scheme: {other} (expected http or https)"
103 )));
104 }
105 };
106 url.set_port(Some(port)).map_err(|_| {
107 RuntimeError::Other("invalid base_url: could not apply default port".into())
108 })?;
109 }
110
111 let path = url.path();
112 if path.is_empty() || path == "/" {
113 url.set_path("/");
114 } else if !path.ends_with('/') {
115 url.set_path(&format!("{}/", path.trim_end_matches('/')));
116 }
117
118 Ok(url)
119}
120
121fn build_reqwest_client(config: &RuntimeConfig) -> Result<reqwest::Client> {
122 let mut headers = HeaderMap::new();
123 if let Some(auth) = &config.auth {
124 merge_auth_headers(&mut headers, auth)?;
125 }
126
127 reqwest::Client::builder()
128 .default_headers(headers)
129 .connect_timeout(config.connect_timeout)
130 .timeout(config.timeout)
131 .build()
132 .map_err(|e| RuntimeError::Other(format!("failed to build HTTP client: {e}")))
133}
134
135fn merge_auth_headers(headers: &mut HeaderMap, auth: &AuthConfig) -> Result<()> {
136 match auth {
137 AuthConfig::BearerToken(token) => {
138 let value = HeaderValue::try_from(format!("Bearer {token}")).map_err(header_err)?;
139 headers.insert(HeaderName::from_static("authorization"), value);
140 }
141 AuthConfig::Basic { username, password } => {
142 let creds = format!("{}:{}", username, password);
143 let b64 = B64.encode(creds.as_bytes());
144 let value = HeaderValue::try_from(format!("Basic {b64}")).map_err(header_err)?;
145 headers.insert(HeaderName::from_static("authorization"), value);
146 }
147 AuthConfig::CustomHeader { key, value } => {
148 let name = HeaderName::try_from(key.as_str()).map_err(header_err)?;
149 let val = HeaderValue::try_from(value.as_str()).map_err(header_err)?;
150 headers.insert(name, val);
151 }
152 }
153 Ok(())
154}
155
156fn header_err<E: std::fmt::Display>(e: E) -> RuntimeError {
157 RuntimeError::Other(format!("invalid auth header: {e}"))
158}
159
160#[cfg(test)]
161mod tests {
162 use super::*;
163 #[test]
164 fn normalizes_empty_path_and_port() {
165 let u = parse_and_normalize_base_url("http://127.0.0.1").unwrap();
166 assert_eq!(u.as_str(), "http://127.0.0.1:11434/");
167 }
168
169 #[test]
170 fn normalizes_subpath_trailing_slash() {
171 let u = parse_and_normalize_base_url("http://proxy/v1/ollama").unwrap();
172 assert!(u.as_str().ends_with("/v1/ollama/"));
173 }
174}