// gobby_core/local_backend.rs

use serde::Deserialize;

3#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
4pub struct Backend {
5    pub name: String,
6    pub url: String,
7    pub probe: String,
8    #[serde(default)]
9    pub auth_token: String,
10}
11
/// Compile-time description of a built-in backend.
///
/// Mirrors [`Backend`] but uses `&'static str` fields so entries can live in a
/// `const` table. Convert with [`BackendDefault::to_backend`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BackendDefault {
    /// Identifier copied into the resulting backend's `name`.
    pub name: &'static str,
    /// Base URL, e.g. `http://localhost:1234`.
    pub url: &'static str,
    /// Probe path joined onto `url`.
    pub probe: &'static str,
    /// Bearer token; empty for none.
    pub auth_token: &'static str,
}

20impl BackendDefault {
21    pub fn to_backend(self) -> Backend {
22        Backend {
23            name: self.name.to_string(),
24            url: self.url.to_string(),
25            probe: self.probe.to_string(),
26            auth_token: self.auth_token.to_string(),
27        }
28    }
29}
30
31pub const DEFAULT_BACKENDS: &[BackendDefault] = &[
32    BackendDefault {
33        name: "lmstudio",
34        url: "http://localhost:1234",
35        probe: "/v1/models",
36        auth_token: "",
37    },
38    BackendDefault {
39        name: "ollama",
40        url: "http://localhost:11434",
41        probe: "/api/tags",
42        auth_token: "",
43    },
44];
45
46pub fn default_backends() -> Vec<Backend> {
47    DEFAULT_BACKENDS
48        .iter()
49        .copied()
50        .map(BackendDefault::to_backend)
51        .collect()
52}
53
54pub fn backend_api_base(backend: &Backend) -> String {
55    format!("{}/v1", backend.url.trim_end_matches('/'))
56}
57
58/// Probe backends in order, return the first that responds successfully.
59pub fn detect_backend(backends: &[Backend], timeout_ms: u64) -> Option<Backend> {
60    for backend in backends {
61        if validate_backend(backend, timeout_ms) {
62            return Some(backend.clone());
63        }
64    }
65    None
66}
67
68/// Validate that a specific backend is reachable.
69pub fn validate_backend(backend: &Backend, timeout_ms: u64) -> bool {
70    let timeout = std::time::Duration::from_millis(timeout_ms);
71    let url = backend_probe_url(backend);
72    let agent = ureq::AgentBuilder::new()
73        .timeout_connect(timeout)
74        .timeout_read(timeout)
75        .build();
76    let mut request = agent.get(&url);
77    let auth_header;
78    let token = backend.auth_token.trim();
79    if !token.is_empty() {
80        auth_header = format!("Bearer {token}");
81        request = request.set("Authorization", &auth_header);
82    }
83    match request.call() {
84        Ok(_) => true,
85        Err(ureq::Error::Status(status, response)) => {
86            log::trace!(
87                "local backend probe `{}` at {} returned HTTP {} {}",
88                backend.name,
89                url,
90                status,
91                response.status_text()
92            );
93            false
94        }
95        Err(error) => {
96            log::trace!(
97                "local backend probe `{}` at {} failed: {}",
98                backend.name,
99                url,
100                error
101            );
102            false
103        }
104    }
105}
106
107fn backend_probe_url(backend: &Backend) -> String {
108    let base = backend.url.trim_end_matches('/');
109    let probe = backend.probe.trim_start_matches('/');
110    if probe.is_empty() {
111        base.to_string()
112    } else {
113        format!("{base}/{probe}")
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::thread;
    use std::time::Duration;

    /// Upper bound on how long the stub server waits for each chunk of request bytes.
    const STUB_READ_TIMEOUT: Duration = Duration::from_secs(5);

    /// Reads from `stream` until the blank line that ends an HTTP request
    /// head, EOF, or an I/O error (including the read timeout).
    ///
    /// A single `read` may return only part of the request, because TCP is
    /// free to split it across segments. Looping keeps the header assertion in
    /// `detects_first_reachable` deterministic.
    fn read_request_head(stream: &mut TcpStream) -> String {
        let _ = stream.set_read_timeout(Some(STUB_READ_TIMEOUT));
        let mut head: Vec<u8> = Vec::new();
        let mut chunk = [0_u8; 1024];
        while !head.windows(4).any(|window| window == b"\r\n\r\n") {
            match stream.read(&mut chunk) {
                Ok(0) | Err(_) => break,
                Ok(read) => head.extend_from_slice(&chunk[..read]),
            }
        }
        String::from_utf8_lossy(&head).into_owned()
    }

    /// Starts a one-shot HTTP stub on an ephemeral 127.0.0.1 port.
    ///
    /// Returns a [`Backend`] pointing at it (auth token `"token"`) and a
    /// handle that yields the raw request head the stub received. The handle
    /// yields an empty string if `accept` failed.
    fn reachable_backend() -> (Backend, thread::JoinHandle<String>) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind stub listener");
        let addr = listener.local_addr().expect("stub listener address");
        let handle = thread::spawn(move || {
            let Ok((mut stream, _)) = listener.accept() else {
                return String::new();
            };
            let request = read_request_head(&mut stream);
            let _ = stream.write_all(
                b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}",
            );
            request
        });

        let backend = Backend {
            name: "reachable".into(),
            url: format!("http://{addr}"),
            probe: "/v1/models".into(),
            auth_token: "token".into(),
        };
        (backend, handle)
    }

    /// A backend on 127.0.0.1 port 9. Assumes nothing listens there on test
    /// machines.
    fn unreachable_backend() -> Backend {
        Backend {
            name: "unreachable".into(),
            url: "http://127.0.0.1:9".into(),
            probe: "/".into(),
            auth_token: String::new(),
        }
    }

    #[test]
    fn detects_first_reachable() {
        let (reachable, handle) = reachable_backend();
        let backends = vec![unreachable_backend(), reachable.clone()];

        assert_eq!(detect_backend(&backends, 500), Some(reachable));
        let request = handle.join().expect("probe request thread");
        assert!(has_header(&request, "authorization", "Bearer token"));
    }

    #[test]
    fn detects_nothing_when_no_backend_answers() {
        assert_eq!(detect_backend(&[unreachable_backend()], 500), None);
    }

    #[test]
    fn default_local_backends_do_not_send_auth_tokens() {
        assert!(
            default_backends()
                .iter()
                .all(|backend| backend.auth_token.is_empty())
        );
    }

    #[test]
    fn api_base_appends_v1_after_trimming_slashes() {
        let backend = Backend {
            url: "http://localhost:1234//".into(),
            ..unreachable_backend()
        };
        assert_eq!(backend_api_base(&backend), "http://localhost:1234/v1");
    }

    #[test]
    fn probe_url_uses_exactly_one_separator() {
        let backend = Backend {
            name: "test".into(),
            url: "http://localhost:1234/".into(),
            probe: "/v1/models".into(),
            auth_token: String::new(),
        };

        assert_eq!(
            backend_probe_url(&backend),
            "http://localhost:1234/v1/models"
        );

        let backend = Backend {
            probe: String::new(),
            ..backend
        };
        assert_eq!(backend_probe_url(&backend), "http://localhost:1234");
    }

    /// True when `request` has a header line named `name` (case-insensitive)
    /// whose trimmed value equals `value` exactly.
    fn has_header(request: &str, name: &str, value: &str) -> bool {
        request.lines().any(|line| {
            line.split_once(':').is_some_and(|(header, actual)| {
                header.trim().eq_ignore_ascii_case(name) && actual.trim() == value
            })
        })
    }
}