Skip to main content

provider_agent/
setup.rs

1//! `usepod-agent setup` — device-flow pairing on first run.
2//!
3//! Replaces the v0.1.x ceremony of "enroll on the dashboard, copy a
4//! 40-char host_token, paste it into agent.toml". The flow now:
5//!
6//!   1. Generate (or load) the agent's Ed25519 identity.
7//!   2. Probe well-known local backend ports (vLLM :8000, llama.cpp
8//!      :8080, LM Studio :1234, Ollama :11434).
9//!   3. POST /v1/host/pair/issue to the coordinator with the agent
10//!      pubkey; receive a short pair_code + poll_token.
11//!   4. Print the pair_code prominently and start long-polling
12//!      /v1/host/pair/poll, sending detected backends as capabilities
13//!      so the dashboard can render the model picker live.
14//!   5. When the operator hits Activate in the dashboard, the next poll
15//!      response delivers host_token + provider_id + activated_models.
16//!   6. Write a complete agent.toml from the discovered + operator-
17//!      configured state. Operator never edits a file.
18
19use std::path::{Path, PathBuf};
20use std::time::Duration;
21
22use anyhow::{Context, Result, anyhow, bail};
23use base64::Engine;
24use reqwest::Client;
25use serde::{Deserialize, Serialize};
26use serde_json::Value;
27use tracing::{debug, info, warn};
28
29use crate::identity::Identity;
30
/// Coordinator base URL used by [`SetupArgs::defaults`].
const DEFAULT_COORDINATOR: &str = "https://api.usepod.ai";

/// Delay between consecutive `/v1/host/pair/poll` requests while the pairing
/// is still reported as `"pending"`.
const POLL_INTERVAL: Duration = Duration::from_millis(800);
33
34#[derive(Debug, Clone)]
35pub struct SetupArgs {
36    pub coordinator: String,
37    pub config_path: PathBuf,
38    pub identity_path: PathBuf,
39}
40
41impl SetupArgs {
42    pub fn defaults() -> Result<Self> {
43        let config_path = default_config_path()?;
44        let identity_path = default_identity_path()?;
45        Ok(Self {
46            coordinator: DEFAULT_COORDINATOR.into(),
47            config_path,
48            identity_path,
49        })
50    }
51}
52
53pub async fn run(args: SetupArgs) -> Result<()> {
54    println!("usepod-agent setup");
55    println!();
56
57    // Identity — generate if missing. Persisting the keypair before pairing
58    // means a re-run of `setup` after a partial pair stays continuous.
59    let identity = crate::identity::load_or_create(&args.identity_path)
60        .context("identity load/create")?;
61    info!(public_key = %identity.public_key_b64(), "identity ready");
62
63    // Backend autodetection. Probes well-known ports with a short timeout.
64    let backends = probe_local_backends().await;
65    if backends.is_empty() {
66        println!("No local backends detected on standard ports.");
67        println!("That's OK — you can pair anyway and configure backends later");
68        println!("from the dashboard, or install one of:");
69        println!("  - Ollama:    https://ollama.ai");
70        println!("  - vLLM:      pip install vllm && vllm serve <model>");
71        println!("  - llama.cpp: ./llama-server -m <model>.gguf");
72        println!();
73    } else {
74        println!("Detected backends:");
75        for b in &backends {
76            println!(
77                "  ✓ {:<10} at {} ({} models)",
78                b.kind,
79                b.url,
80                b.models.len()
81            );
82        }
83        println!();
84    }
85
86    // Issue pair code.
87    let http = Client::builder()
88        .timeout(Duration::from_secs(30))
89        .build()?;
90    let issue = issue_pair_code(&http, &args.coordinator, &identity).await?;
91
92    print_pair_banner(&issue.pair_code, &args.coordinator);
93
94    // Long-poll for claim.
95    let active = match poll_until_active(&http, &args.coordinator, &issue.poll_token, &backends)
96        .await?
97    {
98        PollOutcome::Active(a) => a,
99        PollOutcome::Expired => {
100            println!();
101            println!("✗ Pair code expired. Run `usepod-agent setup` again.");
102            std::process::exit(1);
103        }
104    };
105
106    println!();
107    println!("✓ Paired as provider {}", active.provider_id);
108    if !active.activated_models.is_empty() {
109        println!("  Operator activated {} model(s):", active.activated_models.len());
110        for m in &active.activated_models {
111            println!("    - {}", m.model_id);
112        }
113    }
114
115    // Write complete agent.toml.
116    let toml_text = render_paired_config(&args, &identity, &backends, &active);
117    if let Some(parent) = args.config_path.parent() {
118        std::fs::create_dir_all(parent).context("create config dir")?;
119    }
120    std::fs::write(&args.config_path, toml_text).context("write agent.toml")?;
121    println!();
122    println!("Wrote {}", args.config_path.display());
123    println!();
124    println!("Run the agent:");
125    println!("  usepod-agent run");
126    println!();
127
128    Ok(())
129}
130
131// ---------------------------------------------------------------------------
132// Pair code banner
133// ---------------------------------------------------------------------------
134
/// Prints the boxed "pair this agent" banner with the dashboard URL and the
/// `pair_code`.
///
/// Any coordinator URL containing `api.usepod.ai` maps to the public
/// dashboard page `https://usepod.ai/host/pair`. For any other coordinator,
/// `/host/pair` is appended to its own URL. Both values are left-aligned and
/// padded to 43 columns, so longer values push the right border out of line.
fn print_pair_banner(pair_code: &str, coordinator: &str) {
    let pair_url = if !coordinator.contains("api.usepod.ai") {
        format!("{}/host/pair", coordinator.trim_end_matches('/'))
    } else {
        String::from("https://usepod.ai/host/pair")
    };

    let banner = [
        String::new(),
        String::from("┌─────────────────────────────────────────────────────────┐"),
        String::from("│ Pair this agent with your Use Pod account:              │"),
        String::from("│                                                         │"),
        format!("│   1. Visit  {:<43} │", pair_url),
        format!("│   2. Code:  {:<43} │", pair_code),
        String::from("│                                                         │"),
        String::from("│ Code expires in 10 minutes. Waiting for pairing...      │"),
        String::from("└─────────────────────────────────────────────────────────┘"),
        String::new(),
    ];
    for line in &banner {
        println!("{line}");
    }
}
152
153// ---------------------------------------------------------------------------
154// HTTP — pair issue / poll
155// ---------------------------------------------------------------------------
156
157#[derive(Debug, Deserialize)]
158struct IssueResponse {
159    pair_code: String,
160    poll_token: String,
161    #[allow(dead_code)]
162    expires_at: String,
163}
164
165async fn issue_pair_code(
166    http: &Client,
167    coordinator: &str,
168    identity: &Identity,
169) -> Result<IssueResponse> {
170    let url = format!("{}/v1/host/pair/issue", coordinator.trim_end_matches('/'));
171    let body = serde_json::json!({
172        "agent_pubkey": identity.public_key_b64(),
173    });
174    let resp = http
175        .post(&url)
176        .json(&body)
177        .send()
178        .await
179        .context("POST /v1/host/pair/issue")?
180        .error_for_status()
181        .context("issue response status")?;
182    let parsed: IssueResponse = resp.json().await.context("issue response parse")?;
183    Ok(parsed)
184}
185
186#[derive(Debug, Clone)]
187struct ActivePairing {
188    #[allow(dead_code)]
189    host_token: String,
190    provider_id: String,
191    activated_models: Vec<ActivatedModel>,
192}
193
194enum PollOutcome {
195    Active(ActivePairing),
196    Expired,
197}
198
199#[derive(Debug, Deserialize, Clone)]
200struct ActivatedModel {
201    model_id: String,
202    #[serde(default)]
203    input_per_1m: u64,
204    #[serde(default)]
205    output_per_1m: u64,
206    #[serde(default = "default_max_concurrent_dl")]
207    max_concurrent: u32,
208}
209
/// Serde default for `ActivatedModel::max_concurrent` when the coordinator
/// omits it.
fn default_max_concurrent_dl() -> u32 {
    const FALLBACK_MAX_CONCURRENT: u32 = 4;
    FALLBACK_MAX_CONCURRENT
}
213
214async fn poll_until_active(
215    http: &Client,
216    coordinator: &str,
217    poll_token: &str,
218    backends: &[ProbedBackend],
219) -> Result<PollOutcome> {
220    let url = format!("{}/v1/host/pair/poll", coordinator.trim_end_matches('/'));
221    let capabilities = capabilities_payload(backends);
222    loop {
223        let body = serde_json::json!({
224            "poll_token":   poll_token,
225            "capabilities": capabilities,
226        });
227        let resp = http
228            .post(&url)
229            .json(&body)
230            .send()
231            .await
232            .context("POST /v1/host/pair/poll")?
233            .error_for_status()
234            .context("poll response status")?;
235        let v: Value = resp.json().await.context("poll response parse")?;
236        let status = v.get("status").and_then(|s| s.as_str()).unwrap_or("");
237        match status {
238            "pending" => {
239                tokio::time::sleep(POLL_INTERVAL).await;
240                continue;
241            }
242            "expired" => return Ok(PollOutcome::Expired),
243            "active" => {
244                let host_token = v
245                    .get("host_token")
246                    .and_then(|s| s.as_str())
247                    .ok_or_else(|| anyhow!("active response missing host_token"))?
248                    .to_string();
249                let provider_id = v
250                    .get("provider_id")
251                    .and_then(|s| s.as_str())
252                    .ok_or_else(|| anyhow!("active response missing provider_id"))?
253                    .to_string();
254                let activated_models: Vec<ActivatedModel> = v
255                    .get("model_config")
256                    .cloned()
257                    .and_then(|mc| serde_json::from_value(mc).ok())
258                    .unwrap_or_default();
259                return Ok(PollOutcome::Active(ActivePairing {
260                    host_token,
261                    provider_id,
262                    activated_models,
263                }));
264            }
265            other => bail!("unexpected poll status: {other}"),
266        }
267    }
268}
269
270fn capabilities_payload(backends: &[ProbedBackend]) -> Value {
271    let backends_json: Vec<Value> = backends
272        .iter()
273        .map(|b| {
274            serde_json::json!({
275                "kind":   b.kind,
276                "url":    b.url,
277                "models": b.models,
278            })
279        })
280        .collect();
281    serde_json::json!({
282        "backends":      backends_json,
283        "agent_version": env!("CARGO_PKG_VERSION"),
284    })
285}
286
287// ---------------------------------------------------------------------------
288// Backend autodetection
289// ---------------------------------------------------------------------------
290
291#[derive(Debug, Clone, Serialize)]
292pub struct ProbedBackend {
293    pub kind: String,
294    pub url: String,
295    pub models: Vec<String>,
296}
297
/// Request timeout applied to the client each backend probe builds.
const PROBE_TIMEOUT: Duration = Duration::from_millis(800);
299
300pub async fn probe_local_backends() -> Vec<ProbedBackend> {
301    let probes = vec![
302        ("vllm",     "http://localhost:8000",  probe_openai_compat as ProbeFn),
303        ("llamacpp", "http://localhost:8080",  probe_openai_compat),
304        ("lmstudio", "http://localhost:1234",  probe_openai_compat),
305        ("ollama",   "http://localhost:11434", probe_ollama),
306    ];
307    let mut out = Vec::new();
308    for (kind, url, probe) in probes {
309        match probe(url).await {
310            Ok(models) if !models.is_empty() => {
311                out.push(ProbedBackend {
312                    kind: kind.into(),
313                    url: url.into(),
314                    models,
315                });
316            }
317            Ok(_) => debug!(kind, url, "backend reachable but no models"),
318            Err(e) => debug!(kind, url, %e, "backend probe failed"),
319        }
320    }
321    out
322}
323
324type ProbeFn = fn(&str) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<String>>> + Send>>;
325
326fn probe_openai_compat(
327    url: &str,
328) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<String>>> + Send>> {
329    let url = url.to_string();
330    Box::pin(async move {
331        let http = Client::builder().timeout(PROBE_TIMEOUT).build()?;
332        let v: Value = http
333            .get(format!("{url}/v1/models"))
334            .send()
335            .await?
336            .error_for_status()?
337            .json()
338            .await?;
339        let models = v
340            .get("data")
341            .and_then(|d| d.as_array())
342            .map(|arr| {
343                arr.iter()
344                    .filter_map(|m| m.get("id").and_then(|i| i.as_str()).map(String::from))
345                    .collect()
346            })
347            .unwrap_or_default();
348        Ok(models)
349    })
350}
351
352fn probe_ollama(
353    url: &str,
354) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Vec<String>>> + Send>> {
355    let url = url.to_string();
356    Box::pin(async move {
357        let http = Client::builder().timeout(PROBE_TIMEOUT).build()?;
358        let v: Value = http
359            .get(format!("{url}/api/tags"))
360            .send()
361            .await?
362            .error_for_status()?
363            .json()
364            .await?;
365        let models = v
366            .get("models")
367            .and_then(|d| d.as_array())
368            .map(|arr| {
369                arr.iter()
370                    .filter_map(|m| m.get("name").and_then(|i| i.as_str()).map(String::from))
371                    .collect()
372            })
373            .unwrap_or_default();
374        Ok(models)
375    })
376}
377
378// ---------------------------------------------------------------------------
379// Config writer
380// ---------------------------------------------------------------------------
381
382fn render_paired_config(
383    args: &SetupArgs,
384    identity: &Identity,
385    backends: &[ProbedBackend],
386    active: &ActivePairing,
387) -> String {
388    let mut s = String::new();
389    s.push_str("# usepod-agent config — generated by `usepod-agent setup`.\n");
390    s.push_str("# Re-run setup to refresh; or hand-edit, the agent will respect it.\n\n");
391
392    s.push_str("[operator]\n");
393    s.push_str("# operator identity is owned by the dashboard since v0.2.0;\n");
394    s.push_str("# this section is preserved for back-compat with v0.1.x.\n");
395    s.push_str("display_name  = \"\"\n");
396    s.push_str("wallet        = \"\"\n\n");
397
398    s.push_str("[coordinator]\n");
399    let ws_url = http_to_ws(&args.coordinator);
400    s.push_str(&format!("url             = \"{ws_url}/provider/connect\"\n"));
401    s.push_str(&format!("# host_token     = \"{}\"  (paired)\n", short_secret(&active.host_token)));
402    s.push_str(&format!("# provider_id   = \"{}\"\n\n", active.provider_id));
403
404    s.push_str("[identity]\n");
405    s.push_str(&format!(
406        "key_path = {:?}\n\n",
407        args.identity_path.display().to_string()
408    ));
409    s.push_str(&format!(
410        "# public_key = \"{}\"\n\n",
411        identity.public_key_b64()
412    ));
413
414    for b in backends {
415        s.push_str("[[backends]]\n");
416        s.push_str(&format!("kind = \"{}\"\n", b.kind));
417        s.push_str(&format!("url  = \"{}\"\n\n", b.url));
418    }
419
420    s.push_str("[pricing]\n");
421    if active.activated_models.is_empty() {
422        s.push_str("default_input_per_1m  = 500_000   # placeholder $0.50/M\n");
423        s.push_str("default_output_per_1m = 1_000_000 # placeholder $1.00/M\n\n");
424    } else {
425        s.push_str("default_input_per_1m  = 500_000\n");
426        s.push_str("default_output_per_1m = 1_000_000\n\n");
427        for m in &active.activated_models {
428            s.push_str(&format!("[pricing.models.{:?}]\n", m.model_id));
429            s.push_str(&format!("input_per_1m  = {}\n", m.input_per_1m));
430            s.push_str(&format!("output_per_1m = {}\n\n", m.output_per_1m));
431        }
432    }
433
434    s.push_str("[limits]\n");
435    let max_concurrent = active
436        .activated_models
437        .first()
438        .map(|m| m.max_concurrent)
439        .unwrap_or(4);
440    s.push_str(&format!("max_concurrent = {max_concurrent}\n"));
441
442    s
443}
444
/// Converts a coordinator HTTP(S) URL to its websocket equivalent.
///
/// Trailing slashes are stripped first. Then `https://` becomes `wss://` and
/// `http://` becomes `ws://`. A URL with any other (or no) scheme is
/// returned unchanged apart from the slash trimming.
fn http_to_ws(url: &str) -> String {
    const SCHEME_MAP: [(&str, &str); 2] = [("https://", "wss://"), ("http://", "ws://")];
    let base = url.trim_end_matches('/');
    SCHEME_MAP
        .iter()
        .find_map(|&(from, to)| base.strip_prefix(from).map(|rest| format!("{to}{rest}")))
        .unwrap_or_else(|| base.to_string())
}
455
/// Shortens a secret for display.
///
/// Strings of more than 16 characters become the first 8 characters, `…`,
/// and the last 4. Shorter strings are returned unchanged.
///
/// Counting and slicing are done on `char`s, not bytes. The previous byte
/// slicing (`&s[..8]`, `&s[len - 4..]`) panicked whenever a non-ASCII
/// character straddled one of those byte offsets. For ASCII input the
/// result is unchanged.
fn short_secret(s: &str) -> String {
    let len = s.chars().count();
    if len <= 16 {
        return s.to_string();
    }
    let head: String = s.chars().take(8).collect();
    let tail: String = s.chars().skip(len - 4).collect();
    format!("{head}…{tail}")
}
463
464// ---------------------------------------------------------------------------
465// Default paths
466// ---------------------------------------------------------------------------
467
468fn default_config_path() -> Result<PathBuf> {
469    let dirs = directories::ProjectDirs::from("ai", "usepod", "usepod-agent")
470        .ok_or_else(|| anyhow!("could not resolve config home"))?;
471    Ok(dirs.config_dir().join("agent.toml"))
472}
473
474fn default_identity_path() -> Result<PathBuf> {
475    let home = std::env::var("HOME").unwrap_or_default();
476    Ok(Path::new(&home).join(".usepod-agent").join("identity.key"))
477}