1#[cfg(feature = "local-backend")]
2use std::io::{self, Read, Write};
3#[cfg(feature = "local-backend")]
4use std::net::{TcpStream, ToSocketAddrs};
5#[cfg(feature = "local-backend")]
6use std::time::Duration;
7
8use serde::Deserialize;
9
/// Upper bound (1 KiB) on how many response bytes a probe buffers while
/// waiting for the HTTP status line.
#[cfg(feature = "local-backend")]
const MAX_PROBE_RESPONSE_BYTES: usize = 1 << 10;
12
13#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
14pub struct Backend {
15 pub name: String,
16 pub url: String,
17 pub probe: String,
18 #[serde(default)]
19 pub auth_token: String,
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23pub struct BackendDefault {
24 pub name: &'static str,
25 pub url: &'static str,
26 pub probe: &'static str,
27 pub auth_token: &'static str,
28}
29
30impl BackendDefault {
31 pub fn to_backend(self) -> Backend {
32 Backend {
33 name: self.name.to_string(),
34 url: self.url.to_string(),
35 probe: self.probe.to_string(),
36 auth_token: self.auth_token.to_string(),
37 }
38 }
39}
40
41pub const DEFAULT_BACKENDS: &[BackendDefault] = &[
42 BackendDefault {
43 name: "lmstudio",
44 url: "http://localhost:1234",
45 probe: "/v1/models",
46 auth_token: "",
47 },
48 BackendDefault {
49 name: "ollama",
50 url: "http://localhost:11434",
51 probe: "/api/tags",
52 auth_token: "",
53 },
54];
55
56pub fn default_backends() -> Vec<Backend> {
57 DEFAULT_BACKENDS
58 .iter()
59 .copied()
60 .map(BackendDefault::to_backend)
61 .collect()
62}
63
64pub fn backend_api_base(backend: &Backend) -> String {
65 format!("{}/v1", backend.url.trim_end_matches('/'))
66}
67
/// Probes `backends` in order and returns a clone of the first one that
/// answers with a 2xx status, or `None` if none do.
///
/// Probing stops at the first success. `timeout_ms` is passed unchanged to
/// every [`validate_backend`] call.
#[cfg(feature = "local-backend")]
pub fn detect_backend(backends: &[Backend], timeout_ms: u64) -> Option<Backend> {
    backends
        .iter()
        .find(|candidate| validate_backend(candidate, timeout_ms))
        .cloned()
}
78
/// Sends one HTTP probe to `backend` and reports whether it answered with a
/// 2xx status.
///
/// The probe URL is `url` joined with `probe`. A non-empty (trimmed)
/// `auth_token` is sent as a bearer token. Each failure case (an unsupported
/// URL, an I/O or parse error, or a non-2xx status) is logged at trace level
/// and yields `false`.
#[cfg(feature = "local-backend")]
pub fn validate_backend(backend: &Backend, timeout_ms: u64) -> bool {
    let probe_url = backend_probe_url(backend);

    let target = match HttpProbeTarget::parse(&probe_url) {
        Some(target) => target,
        None => {
            log::trace!(
                "local backend probe `{}` at {} failed: unsupported HTTP URL",
                backend.name,
                probe_url
            );
            return false;
        }
    };

    let token = backend.auth_token.trim();
    let status = match send_probe_request(&target, token, Duration::from_millis(timeout_ms)) {
        Ok(status) => status,
        Err(error) => {
            log::trace!(
                "local backend probe `{}` at {} failed: {}",
                backend.name,
                probe_url,
                error
            );
            return false;
        }
    };

    if (200..300).contains(&status) {
        return true;
    }
    log::trace!(
        "local backend probe `{}` at {} returned HTTP {}",
        backend.name,
        probe_url,
        status
    );
    false
}
115
/// A parsed plain-`http://` probe URL: where to connect and what to request.
#[cfg(feature = "local-backend")]
#[derive(Debug, Clone, PartialEq, Eq)]
struct HttpProbeTarget {
    /// Host name or IP literal; IPv6 literals are stored without brackets.
    host: String,
    /// TCP port (80 when the URL omits it).
    port: u16,
    /// Request-target sent on the request line; always starts with `/`.
    path: String,
}

#[cfg(feature = "local-backend")]
impl HttpProbeTarget {
    /// Parses `http://host[:port][/path][?query][#fragment]`.
    ///
    /// Returns `None` in these cases:
    /// - any scheme other than lowercase `http://`;
    /// - an authority rejected by `parse_http_authority`;
    /// - any ASCII whitespace or control character, which would otherwise be
    ///   copied verbatim into the request line.
    ///
    /// The fragment is dropped because it is never part of an HTTP
    /// request-target.
    fn parse(url: &str) -> Option<Self> {
        if url
            .chars()
            .any(|c| c.is_ascii_whitespace() || c.is_ascii_control())
        {
            return None;
        }
        let rest = url.strip_prefix("http://")?;
        // The authority ends at the first `/`, `?` or `#` (RFC 3986 §3.2).
        let authority_end = rest
            .find(|c: char| matches!(c, '/' | '?' | '#'))
            .unwrap_or(rest.len());
        let (authority, remainder) = rest.split_at(authority_end);
        let (host, port) = parse_http_authority(authority)?;
        let request_target = remainder
            .split_once('#')
            .map_or(remainder, |(before, _)| before);
        let path = if request_target.starts_with('/') {
            request_target.to_string()
        } else {
            format!("/{request_target}")
        };
        Some(Self { host, port, path })
    }

    /// The host as it must appear next to a port or in a `Host` header.
    /// IPv6 literals (the only hosts containing `:`) get brackets.
    fn bracketed_host(&self) -> String {
        if self.host.contains(':') {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        }
    }

    /// `host:port` string suitable for `ToSocketAddrs`.
    fn socket_addr(&self) -> String {
        format!("{}:{}", self.bracketed_host(), self.port)
    }

    /// Value for the `Host` header. The port is omitted when it is the HTTP
    /// default (80). IPv6 literals stay bracketed even then, as RFC 9110 §7.2
    /// requires.
    fn host_header(&self) -> String {
        let host = self.bracketed_host();
        if self.port == 80 {
            host
        } else {
            format!("{host}:{}", self.port)
        }
    }
}
155
/// Splits an `http` URL authority into `(host, port)`.
///
/// Accepts `host`, `host:port`, `[ipv6]` and `[ipv6]:port`. The port defaults
/// to 80 when it is absent or empty (RFC 3986 §3.2.3). Returns `None` in these
/// cases:
/// - an empty authority;
/// - userinfo (`@`);
/// - an unbracketed IPv6 address;
/// - stray brackets;
/// - a port that is not purely ASCII digits or does not fit in a `u16`.
#[cfg(feature = "local-backend")]
fn parse_http_authority(authority: &str) -> Option<(String, u16)> {
    /// Parses the text after `:`. Only digits are allowed, so `+80` is
    /// rejected even though `u16::from_str` would accept it.
    fn parse_port(digits: &str) -> Option<u16> {
        if digits.is_empty() {
            return Some(80);
        }
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        digits.parse().ok()
    }

    if authority.is_empty() || authority.contains('@') {
        return None;
    }
    if let Some(rest) = authority.strip_prefix('[') {
        let (host, suffix) = rest.split_once(']')?;
        if host.is_empty() {
            return None;
        }
        let port = if suffix.is_empty() {
            80
        } else {
            parse_port(suffix.strip_prefix(':')?)?
        };
        return Some((host.to_string(), port));
    }
    if authority.contains('[') || authority.contains(']') {
        return None;
    }
    match authority.rsplit_once(':') {
        // A second `:` means an unbracketed IPv6 literal, which is ambiguous.
        Some((host, port)) if !host.is_empty() && !host.contains(':') => {
            Some((host.to_string(), parse_port(port)?))
        }
        Some(_) => None,
        None => Some((authority.to_string(), 80)),
    }
}
184
/// Sends a minimal HTTP/1.1 `GET` for `target.path` and returns the status
/// code from the first line of the response.
///
/// Every address `target` resolves to is tried in turn. A `localhost` that
/// resolves to `::1` before `127.0.0.1` can therefore still reach a server
/// listening only on IPv4. `timeout` bounds each connect attempt and each
/// individual read/write, not the whole exchange. Only the status line is read
/// (at most about [`MAX_PROBE_RESPONSE_BYTES`]); the body is never consumed.
///
/// # Errors
///
/// Fails in any of these cases:
/// - `auth_token` contains an ASCII control character, which could inject
///   extra header lines;
/// - name resolution yields no address, or no address accepts a connection;
/// - a write fails, or a read fails (a timeout after part of the response has
///   arrived is not treated as a failure);
/// - no HTTP status line can be parsed from the bytes received.
#[cfg(feature = "local-backend")]
fn send_probe_request(
    target: &HttpProbeTarget,
    auth_token: &str,
    timeout: Duration,
) -> io::Result<u16> {
    if auth_token.chars().any(|c| c.is_ascii_control()) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "auth token contains control characters",
        ));
    }

    let mut stream = connect_first_reachable(&target.socket_addr(), timeout)?;
    stream.set_read_timeout(Some(timeout))?;
    stream.set_write_timeout(Some(timeout))?;

    let mut request = format!(
        "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: gobby-core\r\nConnection: close\r\n",
        target.path,
        target.host_header()
    );
    if !auth_token.is_empty() {
        request.push_str("Authorization: Bearer ");
        request.push_str(auth_token);
        request.push_str("\r\n");
    }
    request.push_str("\r\n");
    stream.write_all(request.as_bytes())?;

    let mut response = Vec::new();
    let mut chunk = [0_u8; 128];
    while response.len() < MAX_PROBE_RESPONSE_BYTES {
        let read = match stream.read(&mut chunk) {
            Ok(0) => break,
            Ok(read) => read,
            // Once some bytes have arrived, a timeout still leaves something
            // worth parsing.
            Err(error)
                if matches!(
                    error.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
                ) && !response.is_empty() =>
            {
                break;
            }
            // A signal interrupting the read is not a probe failure; retry.
            Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
            Err(error) => return Err(error),
        };
        response.extend_from_slice(&chunk[..read]);
        // Only the status line is needed. Earlier chunks are already known to
        // contain no newline, so only the new one is checked.
        if chunk[..read].contains(&b'\n') {
            break;
        }
    }
    parse_http_status(&response)
}

/// Connects to the first address `addr` resolves to that accepts a TCP
/// connection within `timeout`. Returns the last connect error if every
/// address fails, or `NotFound` if resolution produced no addresses.
#[cfg(feature = "local-backend")]
fn connect_first_reachable(addr: &str, timeout: Duration) -> io::Result<TcpStream> {
    let mut last_error = None;
    for candidate in addr.to_socket_addrs()? {
        match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(error) => last_error = Some(error),
        }
    }
    Err(last_error.unwrap_or_else(|| {
        io::Error::new(io::ErrorKind::NotFound, "no resolved address")
    }))
}
238
/// Extracts the status code from the first line of an HTTP response.
///
/// The line must look like `HTTP/<version> <3-digit code> ...` (RFC 9112
/// §4). Any other first line is rejected with `InvalidData`, including a
/// non-HTTP banner whose second word happens to be numeric, or a code such as
/// `+20`. Invalid UTF-8 is replaced lossily before parsing.
#[cfg(feature = "local-backend")]
fn parse_http_status(response: &[u8]) -> io::Result<u16> {
    let response = String::from_utf8_lossy(response);
    response
        .lines()
        .next()
        .and_then(|status_line| {
            let mut fields = status_line.split_whitespace();
            let version = fields.next()?;
            let code = fields.next()?;
            let looks_like_status = version.starts_with("HTTP/")
                && code.len() == 3
                && code.bytes().all(|b| b.is_ascii_digit());
            looks_like_status.then_some(code)
        })
        .and_then(|code| code.parse().ok())
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "missing HTTP status"))
}
249
/// Joins `backend.url` and `backend.probe` with exactly one `/`. Trailing
/// slashes on the URL and leading slashes on the probe are removed first. An
/// empty probe yields just the trimmed URL.
#[cfg(feature = "local-backend")]
fn backend_probe_url(backend: &Backend) -> String {
    let mut joined = backend.url.trim_end_matches('/').to_owned();
    match backend.probe.trim_start_matches('/') {
        "" => {}
        probe_path => {
            joined.push('/');
            joined.push_str(probe_path);
        }
    }
    joined
}
260
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_local_backends_do_not_send_auth_tokens() {
        assert!(default_backends()
            .iter()
            .all(|backend| backend.auth_token.is_empty()));
    }

    #[test]
    fn default_backends_preserve_declared_order() {
        let names: Vec<String> = default_backends()
            .into_iter()
            .map(|backend| backend.name)
            .collect();
        assert_eq!(names, ["lmstudio", "ollama"]);
    }

    #[test]
    fn api_base_appends_v1_after_trimming_trailing_slashes() {
        let backend = Backend {
            name: "test".into(),
            url: "http://localhost:1234//".into(),
            probe: String::new(),
            auth_token: String::new(),
        };
        assert_eq!(backend_api_base(&backend), "http://localhost:1234/v1");
    }

    #[cfg(feature = "local-backend")]
    mod http {
        use super::*;
        use std::io::{Read, Write};
        use std::net::{TcpListener, TcpStream};
        use std::sync::mpsc;
        use std::thread;
        use std::time::Duration;

        const OK_RESPONSE: &[u8] =
            b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}";
        const NOT_FOUND_RESPONSE: &[u8] =
            b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";

        /// Reads until the blank line ending the request headers (or EOF or an
        /// error). A request split across TCP segments therefore never reaches
        /// the assertions half-received.
        fn read_request_head(stream: &mut TcpStream) -> String {
            let mut request = Vec::new();
            let mut buffer = [0_u8; 256];
            while !request.windows(4).any(|window| window == b"\r\n\r\n") {
                match stream.read(&mut buffer) {
                    Ok(0) | Err(_) => break,
                    Ok(read) => request.extend_from_slice(&buffer[..read]),
                }
            }
            String::from_utf8_lossy(&request).into_owned()
        }

        /// Starts a one-shot server on an ephemeral loopback port. Returns its
        /// `http://` URL and a channel that yields the request head it received.
        fn spawn_server(response: &'static [u8]) -> (String, mpsc::Receiver<String>) {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let (sender, receiver) = mpsc::channel();
            thread::spawn(move || {
                if let Ok((mut stream, _)) = listener.accept() {
                    let request = read_request_head(&mut stream);
                    let _ = stream.write_all(response);
                    let _ = sender.send(request);
                }
            });
            (format!("http://{addr}"), receiver)
        }

        /// Waits a bounded time for the server's request, so a probe that
        /// never connects fails the test instead of hanging it.
        fn received_request(requests: &mpsc::Receiver<String>) -> String {
            requests
                .recv_timeout(Duration::from_secs(5))
                .expect("probe request was not received")
        }

        fn reachable_backend() -> (Backend, mpsc::Receiver<String>) {
            let (url, requests) = spawn_server(OK_RESPONSE);
            (
                Backend {
                    name: "reachable".into(),
                    url,
                    probe: "/v1/models".into(),
                    auth_token: "token".into(),
                },
                requests,
            )
        }

        fn unreachable_backend() -> Backend {
            Backend {
                name: "unreachable".into(),
                url: "http://127.0.0.1:9".into(),
                probe: "/".into(),
                auth_token: String::new(),
            }
        }

        #[test]
        fn detects_first_reachable() {
            let (reachable, requests) = reachable_backend();
            let backends = vec![unreachable_backend(), reachable.clone()];

            assert_eq!(detect_backend(&backends, 500), Some(reachable));
            let request = received_request(&requests);
            assert!(has_header(&request, "authorization", "Bearer token"));
        }

        #[test]
        fn rejects_non_success_status() {
            let (url, requests) = spawn_server(NOT_FOUND_RESPONSE);
            let authority = url.trim_start_matches("http://").to_owned();
            let backend = Backend {
                name: "missing".into(),
                url,
                probe: "/v1/models".into(),
                auth_token: String::new(),
            };

            assert!(!validate_backend(&backend, 500));
            let request = received_request(&requests);
            assert!(request.starts_with("GET /v1/models HTTP/1.1\r\n"));
            assert!(has_header(&request, "host", &authority));
            assert!(!request.to_ascii_lowercase().contains("authorization:"));
        }

        #[test]
        fn rejects_unsupported_scheme() {
            let backend = Backend {
                name: "tls".into(),
                url: "https://localhost:1234".into(),
                probe: "/v1/models".into(),
                auth_token: String::new(),
            };
            assert!(!validate_backend(&backend, 100));
        }

        #[test]
        fn parses_probe_targets() {
            assert_eq!(
                HttpProbeTarget::parse("http://localhost:1234/v1/models"),
                Some(HttpProbeTarget {
                    host: "localhost".into(),
                    port: 1234,
                    path: "/v1/models".into(),
                })
            );

            let ipv6 = HttpProbeTarget::parse("http://[::1]:8080").expect("IPv6 target");
            assert_eq!(ipv6.path, "/");
            assert_eq!(ipv6.socket_addr(), "[::1]:8080");
            assert_eq!(ipv6.host_header(), "[::1]:8080");

            assert_eq!(HttpProbeTarget::parse("https://localhost:1234"), None);
        }

        #[test]
        fn parses_authorities() {
            assert_eq!(
                parse_http_authority("localhost"),
                Some(("localhost".to_string(), 80))
            );
            assert_eq!(parse_http_authority("user@localhost:1234"), None);
            assert_eq!(parse_http_authority("::1"), None);
        }

        #[test]
        fn parses_status_line() {
            assert_eq!(parse_http_status(b"HTTP/1.1 204 No Content\r\n").unwrap(), 204);
            assert!(parse_http_status(b"").is_err());
        }

        #[test]
        fn probe_url_uses_exactly_one_separator() {
            let backend = Backend {
                name: "test".into(),
                url: "http://localhost:1234/".into(),
                probe: "/v1/models".into(),
                auth_token: String::new(),
            };

            assert_eq!(
                backend_probe_url(&backend),
                "http://localhost:1234/v1/models"
            );

            let backend = Backend {
                probe: String::new(),
                ..backend
            };
            assert_eq!(backend_probe_url(&backend), "http://localhost:1234");
        }

        /// True if `request` has a header `name` (case-insensitive) whose
        /// trimmed value equals `value` exactly.
        fn has_header(request: &str, name: &str, value: &str) -> bool {
            request.lines().any(|line| {
                line.split_once(':').is_some_and(|(header, actual)| {
                    header.trim().eq_ignore_ascii_case(name) && actual.trim() == value
                })
            })
        }
    }
}