Skip to main content

convergio_provisioning/
routes.rs

1//! HTTP routes for node provisioning.
2//!
3//! - POST /api/provision/peer     — trigger provisioning for a peer
4//! - GET  /api/provision/runs     — list provisioning runs
5//! - GET  /api/provision/run/:id  — get a specific run with items
6
7use std::sync::Arc;
8
9use axum::extract::{Path, Query, State};
10use axum::response::Json;
11use axum::routing::{get, post};
12use axum::Router;
13use convergio_db::pool::ConnPool;
14use serde::Deserialize;
15use serde_json::{json, Value};
16
17use crate::provision::provision_peer;
18use crate::types::ProvisionRequest;
19
/// Report whether `s` contains a character that could trigger shell
/// interpretation on the remote side.
///
/// Rejected characters:
/// - command separators and pipes: `;`, `&`, `|`
/// - expansion and substitution: `$`, `` ` ``
/// - quoting and escaping: `'`, `"`, `\`
/// - redirection: `<`, `>`
/// - subshell / grouping: `(`, `)`
/// - line breaks and NUL: `\n`, `\r`, `\0`
///
/// Returns `false` for the empty string.
fn has_shell_metachar(s: &str) -> bool {
    s.contains(|c: char| {
        matches!(
            c,
            ';' | '&'
                | '|'
                | '$'
                | '`'
                | '\''
                | '"'
                | '\\'
                | '\n'
                | '\r'
                | '<'
                | '>'
                | '('
                | ')'
                | '\0'
        )
    })
}
29
30/// Validate provisioning request fields to prevent injection/traversal.
31pub fn validate_request(req: &ProvisionRequest) -> Result<(), String> {
32    // peer_name: alphanumeric, dash, underscore, dot only
33    if req.peer_name.is_empty() || req.peer_name.len() > 128 {
34        return Err("peer_name must be 1-128 characters".into());
35    }
36    if !req
37        .peer_name
38        .chars()
39        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
40    {
41        return Err("peer_name must contain only alphanumeric, dash, underscore, or dot".into());
42    }
43
44    // ssh_target: length + no shell metacharacters
45    if req.ssh_target.is_empty() || req.ssh_target.len() > 256 {
46        return Err("ssh_target must be 1-256 characters".into());
47    }
48    if has_shell_metachar(&req.ssh_target) {
49        return Err("ssh_target contains invalid characters".into());
50    }
51
52    // remote_base: length, no traversal, no shell metacharacters
53    if req.remote_base.is_empty() || req.remote_base.len() > 512 {
54        return Err("remote_base must be 1-512 characters".into());
55    }
56    if req.remote_base.contains("..") {
57        return Err("remote_base must not contain path traversal".into());
58    }
59    if has_shell_metachar(&req.remote_base) {
60        return Err("remote_base contains invalid characters".into());
61    }
62
63    Ok(())
64}
65
66pub struct ProvisionState {
67    pub pool: ConnPool,
68}
69
70pub fn provision_routes(state: Arc<ProvisionState>) -> Router {
71    Router::new()
72        .route("/api/provision/peer", post(handle_provision))
73        .route("/api/provision/runs", get(handle_list_runs))
74        .route("/api/provision/run/{id}", get(handle_get_run))
75        .with_state(state)
76}
77
78async fn handle_provision(
79    State(s): State<Arc<ProvisionState>>,
80    Json(req): Json<ProvisionRequest>,
81) -> Json<Value> {
82    if let Err(e) = validate_request(&req) {
83        return Json(json!({"error": e}));
84    }
85    let pool = s.pool.clone();
86    let peer = req.peer_name.clone();
87    let handle = tokio::spawn(async move {
88        match provision_peer(&pool, &req).await {
89            Ok(run_id) => tracing::info!(run_id, peer = %peer, "provisioning complete"),
90            Err(e) => tracing::warn!(peer = %peer, error = %e, "provisioning failed"),
91        }
92    });
93    // Log if the spawned task panics
94    tokio::spawn(async move {
95        if let Err(e) = handle.await {
96            tracing::error!(error = %e, "provisioning task panicked");
97        }
98    });
99    Json(json!({"ok": true, "message": "provisioning started"}))
100}
101
102#[derive(Deserialize, Default)]
103struct ListQuery {
104    limit: Option<u32>,
105}
106
107async fn handle_list_runs(
108    State(s): State<Arc<ProvisionState>>,
109    Query(q): Query<ListQuery>,
110) -> Json<Value> {
111    let conn = match s.pool.get() {
112        Ok(c) => c,
113        Err(e) => return Json(json!({"error": e.to_string()})),
114    };
115    let limit = q.limit.unwrap_or(20).min(100);
116    let mut stmt = match conn.prepare(
117        "SELECT id, peer_name, ssh_target, status, items_total, items_done, \
118         started_at, completed_at, error_message \
119         FROM provision_runs ORDER BY id DESC LIMIT ?1",
120    ) {
121        Ok(s) => s,
122        Err(e) => return Json(json!({"error": e.to_string()})),
123    };
124    let rows: Vec<Value> = stmt
125        .query_map([limit], |r| {
126            Ok(json!({
127                "id": r.get::<_, i64>(0)?,
128                "peer_name": r.get::<_, String>(1)?,
129                "ssh_target": r.get::<_, String>(2)?,
130                "status": r.get::<_, String>(3)?,
131                "items_total": r.get::<_, u32>(4)?,
132                "items_done": r.get::<_, u32>(5)?,
133                "started_at": r.get::<_, String>(6)?,
134                "completed_at": r.get::<_, Option<String>>(7)?,
135                "error": r.get::<_, Option<String>>(8)?,
136            }))
137        })
138        .map(|rows| rows.flatten().collect())
139        .unwrap_or_default();
140    Json(json!({"runs": rows}))
141}
142
143async fn handle_get_run(State(s): State<Arc<ProvisionState>>, Path(id): Path<i64>) -> Json<Value> {
144    let conn = match s.pool.get() {
145        Ok(c) => c,
146        Err(e) => return Json(json!({"error": e.to_string()})),
147    };
148    let run = conn.query_row(
149        "SELECT id, peer_name, ssh_target, status, items_total, items_done, \
150         started_at, completed_at, error_message FROM provision_runs WHERE id = ?1",
151        [id],
152        |r| {
153            Ok(json!({
154                "id": r.get::<_, i64>(0)?,
155                "peer_name": r.get::<_, String>(1)?,
156                "ssh_target": r.get::<_, String>(2)?,
157                "status": r.get::<_, String>(3)?,
158                "items_total": r.get::<_, u32>(4)?,
159                "items_done": r.get::<_, u32>(5)?,
160                "started_at": r.get::<_, String>(6)?,
161                "completed_at": r.get::<_, Option<String>>(7)?,
162                "error": r.get::<_, Option<String>>(8)?,
163            }))
164        },
165    );
166    let items = list_items(&conn, id);
167    match run {
168        Ok(r) => Json(json!({"run": r, "items": items})),
169        Err(e) => Json(json!({"error": e.to_string()})),
170    }
171}
172
173fn list_items(conn: &rusqlite::Connection, run_id: i64) -> Vec<Value> {
174    let mut stmt = match conn.prepare(
175        "SELECT id, item_type, source_path, dest_path, status, \
176         bytes_transferred, duration_ms, error_message \
177         FROM provision_items WHERE run_id = ?1 ORDER BY id",
178    ) {
179        Ok(s) => s,
180        Err(_) => return vec![],
181    };
182    stmt.query_map([run_id], |r| {
183        Ok(json!({
184            "id": r.get::<_, i64>(0)?,
185            "item_type": r.get::<_, String>(1)?,
186            "source": r.get::<_, String>(2)?,
187            "dest": r.get::<_, String>(3)?,
188            "status": r.get::<_, String>(4)?,
189            "bytes": r.get::<_, u64>(5)?,
190            "duration_ms": r.get::<_, u64>(6)?,
191            "error": r.get::<_, Option<String>>(7)?,
192        }))
193    })
194    .map(|rows| rows.flatten().collect())
195    .unwrap_or_default()
196}