provider_agent/
discovery.rs1use std::collections::HashMap;
10use std::sync::Arc;
11
12use serde::Serialize;
13use serde_json::{Value, json};
14use tracing::{debug, info, warn};
15
16use crate::backend::{
17 Backend, LlamaCppBackend, LmStudioBackend, OllamaBackend, OpenRouterBackend, VeniceBackend,
18 VllmBackend,
19};
20use crate::config::{Backend as CfgBackend, Config};
21
22const AGENT_VERSION: &str = env!("CARGO_PKG_VERSION");
23
/// Maps a backend kind string to its preference rank.
///
/// Smaller numbers are preferred: `dedupe_by_priority` keeps the model entry
/// whose backend kind has the lowest rank. Matching is exact and
/// case-sensitive; any kind not listed here gets the fallback rank of 100.
fn kind_priority(kind: &str) -> u8 {
    // Known kinds paired with their rank.
    const RANKS: [(&str, u8); 6] = [
        ("vllm", 0),
        ("llamacpp", 1),
        ("lmstudio", 2),
        ("ollama", 3),
        ("venice", 4),
        ("openrouter", 5),
    ];
    // Rank assigned to anything that is not in the table.
    const UNRANKED: u8 = 100;

    RANKS
        .iter()
        .find(|(name, _)| *name == kind)
        .map_or(UNRANKED, |&(_, rank)| rank)
}
37
38#[derive(Debug, Clone, Serialize)]
40pub struct ResolvedModel {
41 pub model_id: String,
42 pub input_per_1m: u64,
43 pub output_per_1m: u64,
44 pub max_concurrent: u32,
45 pub backend: String,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 pub context_window: Option<u32>,
48}
49
50pub struct DiscoveredBackend {
52 pub backend: Arc<dyn Backend>,
53 pub models: Vec<ResolvedModel>,
54}
55
56pub struct DiscoveryResult {
58 pub backends: Vec<DiscoveredBackend>,
59 pub capability_models: Vec<ResolvedModel>,
61}
62
63impl DiscoveryResult {
64 pub fn to_capabilities(&self, cfg: &Config) -> Value {
66 let limits = json!({
67 "max_concurrent_total": cfg.limits.max_concurrent,
68 "max_tokens_per_minute": cfg.limits.max_tokens_per_minute,
69 });
70 json!({
71 "type": "capabilities",
72 "models": self.capability_models,
73 "limits": limits,
74 "metadata": {
75 "agent_version": AGENT_VERSION,
76 }
77 })
78 }
79}
80
81pub fn build_backend(cfg: &CfgBackend) -> Option<Arc<dyn Backend>> {
84 match cfg.kind.as_str() {
85 "vllm" => cfg
86 .url
87 .as_deref()
88 .map(|u| Arc::new(VllmBackend::new(u)) as Arc<dyn Backend>),
89 "llamacpp" => cfg
90 .url
91 .as_deref()
92 .map(|u| Arc::new(LlamaCppBackend::new(u)) as Arc<dyn Backend>),
93 "lmstudio" => cfg
94 .url
95 .as_deref()
96 .map(|u| Arc::new(LmStudioBackend::new(u)) as Arc<dyn Backend>),
97 "ollama" => cfg
98 .url
99 .as_deref()
100 .map(|u| Arc::new(OllamaBackend::new(u)) as Arc<dyn Backend>),
101 "openrouter" => match cfg
102 .api_key_env
103 .as_deref()
104 .map(OpenRouterBackend::from_env)
105 {
106 Some(Ok(b)) => Some(Arc::new(b) as Arc<dyn Backend>),
107 Some(Err(e)) => {
108 warn!(?e, "skipping openrouter backend (no api key)");
109 None
110 }
111 None => None,
112 },
113 "venice" => match cfg.api_key_env.as_deref().map(VeniceBackend::from_env) {
114 Some(Ok(b)) => Some(Arc::new(b) as Arc<dyn Backend>),
115 Some(Err(e)) => {
116 warn!(?e, "skipping venice backend (no api key)");
117 None
118 }
119 None => None,
120 },
121 other => {
122 warn!(kind = other, "unknown backend kind in config; skipping");
123 None
124 }
125 }
126}
127
128pub async fn run(cfg: &Config) -> DiscoveryResult {
132 let mut discovered: Vec<DiscoveredBackend> = Vec::new();
133
134 for cfg_backend in &cfg.backends {
135 let Some(backend) = build_backend(cfg_backend) else {
136 continue;
137 };
138
139 let health = match backend.health().await {
140 Ok(h) => h,
141 Err(e) => {
142 warn!(backend = backend.id(), ?e, "health check failed");
143 continue;
144 }
145 };
146 if !health.reachable {
147 warn!(
148 backend = backend.id(),
149 error = ?health.last_error,
150 "backend unreachable; skipping"
151 );
152 continue;
153 }
154 debug!(
155 backend = backend.id(),
156 latency_ms = ?health.latency_ms,
157 "backend healthy"
158 );
159
160 let models = match backend.list_models().await {
161 Ok(m) => m,
162 Err(e) => {
163 warn!(backend = backend.id(), ?e, "list_models failed");
164 continue;
165 }
166 };
167
168 let allow: Option<&Vec<String>> = cfg_backend.models.as_ref();
170 let resolved: Vec<ResolvedModel> = models
171 .into_iter()
172 .filter(|m| match allow {
173 Some(list) => list.iter().any(|x| x == &m.model_id),
174 None => true,
175 })
176 .filter_map(|m| {
177 let (input_per_1m, output_per_1m) = match cfg.pricing.models.get(&m.model_id) {
178 Some(p) => (p.input_per_1m, p.output_per_1m),
179 None => (
180 cfg.pricing.default_input_per_1m,
181 cfg.pricing.default_output_per_1m,
182 ),
183 };
184 if input_per_1m == 0 || output_per_1m == 0 {
185 warn!(model = %m.model_id, "no pricing; dropping model");
186 return None;
187 }
188 Some(ResolvedModel {
189 model_id: m.model_id,
190 input_per_1m,
191 output_per_1m,
192 max_concurrent: cfg.limits.max_concurrent,
193 backend: backend.kind().to_string(),
194 context_window: m.context_window,
195 })
196 })
197 .collect();
198
199 info!(
200 backend = backend.id(),
201 kind = backend.kind(),
202 models = resolved.len(),
203 "backend discovered"
204 );
205 discovered.push(DiscoveredBackend {
206 backend,
207 models: resolved,
208 });
209 }
210
211 let capability_models = dedupe_by_priority(&discovered);
212 DiscoveryResult {
213 backends: discovered,
214 capability_models,
215 }
216}
217
218fn dedupe_by_priority(discovered: &[DiscoveredBackend]) -> Vec<ResolvedModel> {
221 let mut by_id: HashMap<String, ResolvedModel> = HashMap::new();
222 for db in discovered {
223 for m in &db.models {
224 match by_id.get(&m.model_id) {
225 Some(existing) if kind_priority(&existing.backend) <= kind_priority(&m.backend) => {
226 continue;
227 }
228 _ => {
229 by_id.insert(m.model_id.clone(), m.clone());
230 }
231 }
232 }
233 }
234 let mut out: Vec<ResolvedModel> = by_id.into_values().collect();
235 out.sort_by(|a, b| a.model_id.cmp(&b.model_id));
236 out
237}