Skip to main content

gobby_core/
local_backend.rs

1#[cfg(feature = "local-backend")]
2use std::io::{self, Read, Write};
3#[cfg(feature = "local-backend")]
4use std::net::{TcpStream, ToSocketAddrs};
5#[cfg(feature = "local-backend")]
6use std::time::Duration;
7
8use serde::Deserialize;
9
10#[cfg(feature = "local-backend")]
/// Soft cap on how many response bytes the probe buffers while waiting for
/// the first line of the reply; reading stops once this many bytes are held.
const MAX_PROBE_RESPONSE_BYTES: usize = 1024;
12
13#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
14pub struct Backend {
15    pub name: String,
16    pub url: String,
17    pub probe: String,
18    #[serde(default)]
19    pub auth_token: String,
20}
21
/// Compile-time description of a built-in backend, made only of `'static`
/// string slices so it can live in a `const` table.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackendDefault {
    /// Label of the built-in backend.
    pub name: &'static str,
    /// Base URL of the built-in backend.
    pub url: &'static str,
    /// Probe path of the built-in backend.
    pub probe: &'static str,
    /// Bearer token of the built-in backend (empty means none).
    pub auth_token: &'static str,
}
29
30impl BackendDefault {
31    pub fn to_backend(self) -> Backend {
32        Backend {
33            name: self.name.to_string(),
34            url: self.url.to_string(),
35            probe: self.probe.to_string(),
36            auth_token: self.auth_token.to_string(),
37        }
38    }
39}
40
41pub const DEFAULT_BACKENDS: &[BackendDefault] = &[
42    BackendDefault {
43        name: "lmstudio",
44        url: "http://localhost:1234",
45        probe: "/v1/models",
46        auth_token: "",
47    },
48    BackendDefault {
49        name: "ollama",
50        url: "http://localhost:11434",
51        probe: "/api/tags",
52        auth_token: "",
53    },
54];
55
56pub fn default_backends() -> Vec<Backend> {
57    DEFAULT_BACKENDS
58        .iter()
59        .copied()
60        .map(BackendDefault::to_backend)
61        .collect()
62}
63
64pub fn backend_api_base(backend: &Backend) -> String {
65    format!("{}/v1", backend.url.trim_end_matches('/'))
66}
67
/// Returns a clone of the first backend in `backends` that passes
/// [`validate_backend`], or `None` when none of them does.
///
/// Backends are probed one after another in slice order, each with the same
/// `timeout_ms`, and probing stops at the first success.
#[cfg(feature = "local-backend")]
pub fn detect_backend(backends: &[Backend], timeout_ms: u64) -> Option<Backend> {
    backends
        .iter()
        .find(|candidate| validate_backend(candidate, timeout_ms))
        .cloned()
}
78
/// Probes a single backend and reports whether it answered with a 2xx status.
///
/// The probe URL is built from the backend's `url` and `probe`, and the
/// trimmed `auth_token` is passed along with the request. Every failure
/// (unsupported URL, I/O error, non-2xx status) is logged at trace level and
/// reported as `false`; this function never returns an error.
#[cfg(feature = "local-backend")]
pub fn validate_backend(backend: &Backend, timeout_ms: u64) -> bool {
    let probe_url = backend_probe_url(backend);
    let target = match HttpProbeTarget::parse(&probe_url) {
        Some(target) => target,
        None => {
            log::trace!(
                "local backend probe `{}` at {} failed: unsupported HTTP URL",
                backend.name,
                probe_url
            );
            return false;
        }
    };

    let token = backend.auth_token.trim();
    let outcome = send_probe_request(&target, token, Duration::from_millis(timeout_ms));
    match outcome {
        Ok(code) if (200..300).contains(&code) => return true,
        Ok(code) => {
            log::trace!(
                "local backend probe `{}` at {} returned HTTP {}",
                backend.name,
                probe_url,
                code
            );
        }
        Err(failure) => {
            log::trace!(
                "local backend probe `{}` at {} failed: {}",
                backend.name,
                probe_url,
                failure
            );
        }
    }
    false
}
115
116#[cfg(feature = "local-backend")]
/// Pieces of a plain `http://` URL needed to open a socket and write a
/// request line.
#[derive(Debug, Clone, PartialEq, Eq)]
struct HttpProbeTarget {
    /// Host name or IP literal, stored without surrounding brackets.
    host: String,
    /// TCP port to connect to.
    port: u16,
    /// Request target; always begins with `/`.
    path: String,
}
123
#[cfg(feature = "local-backend")]
impl HttpProbeTarget {
    /// Splits a plain `http://` URL into host, port, and request target.
    ///
    /// Returns `None` when the URL does not start with `http://` or when
    /// [`parse_http_authority`] rejects the authority. The authority ends at
    /// the first `/` or `?`, so a query on a path-less URL
    /// (`http://host:1234?x=1`) is kept as part of the request target instead
    /// of being mistaken for part of the port. Any `#fragment` is dropped
    /// because fragments are never sent to a server.
    fn parse(url: &str) -> Option<Self> {
        let rest = url.strip_prefix("http://")?;
        let rest = rest.split_once('#').map_or(rest, |(before, _)| before);
        let authority_end = rest.find(['/', '?']).unwrap_or(rest.len());
        let (authority, target) = rest.split_at(authority_end);
        let (host, port) = parse_http_authority(authority)?;
        // The request target must start with `/`, even when the URL had no
        // path (`http://host` or `http://host?x=1`).
        let path = if target.starts_with('/') {
            target.to_string()
        } else {
            format!("/{target}")
        };
        Some(Self { host, port, path })
    }

    /// Host as it must appear inside a URL authority: a host containing `:`
    /// is treated as an IPv6 literal and wrapped in brackets.
    fn bracketed_host(&self) -> String {
        if self.host.contains(':') {
            format!("[{}]", self.host)
        } else {
            self.host.clone()
        }
    }

    /// `host:port` string suitable for `ToSocketAddrs`.
    fn socket_addr(&self) -> String {
        format!("{}:{}", self.bracketed_host(), self.port)
    }

    /// Value for the `Host` request header.
    ///
    /// The port is omitted when it is the HTTP default (80). An IPv6 literal
    /// keeps its brackets in both cases; without them the header value would
    /// not be a valid host.
    fn host_header(&self) -> String {
        if self.port == 80 {
            self.bracketed_host()
        } else {
            self.socket_addr()
        }
    }
}
155
156#[cfg(feature = "local-backend")]
/// Splits a URL authority into `(host, port)`, defaulting the port to 80.
///
/// Accepted forms are `host`, `host:port`, `[literal]`, and `[literal]:port`;
/// the brackets are stripped from the returned host. Returns `None` for an
/// empty authority, any authority containing `@`, stray or empty brackets,
/// an unbracketed host containing `:`, an empty host, or a port that does
/// not parse as `u16`.
fn parse_http_authority(authority: &str) -> Option<(String, u16)> {
    const DEFAULT_HTTP_PORT: u16 = 80;

    if authority.is_empty() || authority.contains('@') {
        return None;
    }

    // Bracketed form: everything up to the first `]` is the host, and
    // whatever follows must be either nothing or `:port`.
    if let Some(after_bracket) = authority.strip_prefix('[') {
        let (literal, tail) = after_bracket.split_once(']')?;
        if literal.is_empty() {
            return None;
        }
        let port = match tail.strip_prefix(':') {
            Some(digits) => digits.parse().ok()?,
            None if tail.is_empty() => DEFAULT_HTTP_PORT,
            None => return None,
        };
        return Some((literal.to_owned(), port));
    }

    // Brackets anywhere else are malformed.
    if authority.contains(['[', ']']) {
        return None;
    }

    let Some((host, port)) = authority.rsplit_once(':') else {
        return Some((authority.to_owned(), DEFAULT_HTTP_PORT));
    };
    // A second `:` means an unbracketed IPv6 literal, which is ambiguous.
    if host.is_empty() || host.contains(':') {
        return None;
    }
    Some((host.to_owned(), port.parse().ok()?))
}
184
/// Sends a minimal `GET` request to `target` and returns the status code
/// taken from the first line of the reply.
///
/// Every address the host resolves to is tried in order until one accepts
/// the connection. A name such as `localhost` commonly resolves to both
/// `::1` and `127.0.0.1`, and a server bound to only one of them would look
/// unreachable if only the first address were tried. `timeout` applies to
/// each connection attempt and to each later read and write individually,
/// so the total time spent can exceed `timeout`.
///
/// When `auth_token` is non-empty it is sent as `Authorization: Bearer …`.
///
/// # Errors
///
/// * `InvalidInput` if the request target or token contains an ASCII control
///   character (which would corrupt the request or inject header lines).
/// * The resolver error, or `NotFound` if resolution yields no address.
/// * The error of the last failed connection attempt if none succeeds.
/// * Any read/write error, except a timeout after some bytes have arrived.
/// * `InvalidData` if the reply carries no parsable status code.
#[cfg(feature = "local-backend")]
fn send_probe_request(
    target: &HttpProbeTarget,
    auth_token: &str,
    timeout: Duration,
) -> io::Result<u16> {
    let has_control = |text: &str| text.chars().any(|c| c.is_ascii_control());
    if has_control(&target.path) || has_control(auth_token) {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "probe request contains control characters",
        ));
    }

    // Keep the most recent failure so the caller sees why the last attempt
    // did not connect; the initial value covers an empty resolution result.
    let mut connection = Err(io::Error::new(
        io::ErrorKind::NotFound,
        "no resolved address",
    ));
    for addr in target.socket_addr().to_socket_addrs()? {
        connection = TcpStream::connect_timeout(&addr, timeout);
        if connection.is_ok() {
            break;
        }
    }
    let mut stream = connection?;
    stream.set_read_timeout(Some(timeout))?;
    stream.set_write_timeout(Some(timeout))?;

    let mut request = format!(
        "GET {} HTTP/1.1\r\nHost: {}\r\nUser-Agent: gobby-core\r\nConnection: close\r\n",
        target.path,
        target.host_header()
    );
    if !auth_token.is_empty() {
        request.push_str("Authorization: Bearer ");
        request.push_str(auth_token);
        request.push_str("\r\n");
    }
    request.push_str("\r\n");
    stream.write_all(request.as_bytes())?;

    // Only the status line is needed: stop at the first newline, at EOF, or
    // once the buffer reaches the size cap.
    let mut response = Vec::new();
    let mut chunk = [0_u8; 128];
    while response.len() < MAX_PROBE_RESPONSE_BYTES {
        let read = match stream.read(&mut chunk) {
            Ok(read) => read,
            // A timeout after partial data is not fatal; parse what arrived.
            Err(error)
                if matches!(
                    error.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
                ) && !response.is_empty() =>
            {
                break;
            }
            Err(error) => return Err(error),
        };
        if read == 0 {
            break;
        }
        response.extend_from_slice(&chunk[..read]);
        if response.contains(&b'\n') {
            break;
        }
    }
    parse_http_status(&response)
}
238
239#[cfg(feature = "local-backend")]
/// Extracts the status code from the first line of an HTTP response.
///
/// The bytes are decoded lossily as UTF-8 and the second
/// whitespace-separated token of the first line is parsed as `u16`.
///
/// # Errors
///
/// Returns `InvalidData` ("missing HTTP status") when there is no first
/// line, no second token, or the token is not a valid `u16`.
fn parse_http_status(response: &[u8]) -> io::Result<u16> {
    let text = String::from_utf8_lossy(response);
    let status_line = text.lines().next().unwrap_or("");
    let code_token = status_line.split_whitespace().nth(1);
    match code_token.map(str::parse::<u16>) {
        Some(Ok(code)) => Ok(code),
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "missing HTTP status",
        )),
    }
}
249
/// Joins the backend's base URL and probe path with exactly one `/`.
///
/// Trailing slashes on `url` and leading slashes on `probe` are removed
/// first; when nothing remains of the probe, the trimmed base URL is
/// returned on its own.
#[cfg(feature = "local-backend")]
fn backend_probe_url(backend: &Backend) -> String {
    let base = backend.url.trim_end_matches('/');
    match backend.probe.trim_start_matches('/') {
        "" => base.to_owned(),
        probe => format!("{base}/{probe}"),
    }
}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// None of the built-in backends may carry a bearer token.
    #[test]
    fn default_local_backends_do_not_send_auth_tokens() {
        for backend in default_backends() {
            assert!(backend.auth_token.is_empty());
        }
    }

    /// Tests that open real loopback sockets; only built with the probe code.
    #[cfg(feature = "local-backend")]
    mod http {
        use super::*;
        use std::io::{Read, Write};
        use std::net::TcpListener;
        use std::thread;

        /// Canned reply written by the one-shot test server.
        const OK_RESPONSE: &[u8] =
            b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}";

        /// Starts a one-shot HTTP server on an ephemeral loopback port.
        ///
        /// Returns a backend pointing at it and a handle whose result is the
        /// raw request text the server received (empty if nothing connected).
        fn reachable_backend() -> (Backend, thread::JoinHandle<String>) {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();
            let server = thread::spawn(move || {
                let Ok((mut stream, _)) = listener.accept() else {
                    return String::new();
                };
                let mut buffer = [0_u8; 1024];
                let read = stream.read(&mut buffer).unwrap_or(0);
                let _ = stream.write_all(OK_RESPONSE);
                String::from_utf8_lossy(&buffer[..read]).to_string()
            });

            let backend = Backend {
                name: "reachable".into(),
                url: format!("http://{}", addr),
                probe: "/v1/models".into(),
                auth_token: "token".into(),
            };
            (backend, server)
        }

        /// Backend on loopback port 9, where no listener is expected.
        fn unreachable_backend() -> Backend {
            Backend {
                name: "unreachable".into(),
                url: "http://127.0.0.1:9".into(),
                probe: "/".into(),
                auth_token: String::new(),
            }
        }

        #[test]
        fn detects_first_reachable() {
            let (reachable, server) = reachable_backend();
            let candidates = vec![unreachable_backend(), reachable.clone()];

            let detected = detect_backend(&candidates, 500);
            assert_eq!(detected, Some(reachable));

            // The configured token must have been forwarded as a bearer header.
            let request = server.join().expect("probe request thread");
            assert!(has_header(&request, "authorization", "Bearer token"));
        }

        #[test]
        fn probe_url_uses_exactly_one_separator() {
            let with_probe = Backend {
                name: "test".into(),
                url: "http://localhost:1234/".into(),
                probe: "/v1/models".into(),
                auth_token: String::new(),
            };
            assert_eq!(
                backend_probe_url(&with_probe),
                "http://localhost:1234/v1/models"
            );

            // An empty probe yields the base URL with no trailing slash.
            let without_probe = Backend {
                probe: String::new(),
                ..with_probe
            };
            assert_eq!(backend_probe_url(&without_probe), "http://localhost:1234");
        }

        /// True when `request` has a `name: value` line; the name is compared
        /// case-insensitively and both sides are trimmed.
        fn has_header(request: &str, name: &str, value: &str) -> bool {
            request
                .lines()
                .filter_map(|line| line.split_once(':'))
                .any(|(header, actual)| {
                    header.trim().eq_ignore_ascii_case(name) && actual.trim() == value
                })
        }
    }
}