1use std::sync::Arc;
8
9use axum::extract::{Path, Query, State};
10use axum::response::Json;
11use axum::routing::{get, post};
12use axum::Router;
13use convergio_db::pool::ConnPool;
14use serde::Deserialize;
15use serde_json::{json, Value};
16
17use crate::provision::provision_peer;
18use crate::types::ProvisionRequest;
19
/// Returns `true` if `s` contains a character a POSIX shell would treat specially.
///
/// Rejected characters:
/// - command separators / pipes: `;`, `&`, `|`
/// - expansion / substitution: `$`, `` ` ``
/// - quoting / escaping: `'`, `"`, `\`
/// - redirection: `<`, `>`
/// - grouping / brace expansion: `(`, `)`, `{`, `}`
/// - globbing: `*`, `?`
/// - any control character (newline, carriage return, tab, NUL, ...)
///
/// `~` is intentionally allowed so home-relative paths such as `~/.data` remain
/// valid. This is a deny-list used by [`validate_request`] on user-supplied
/// values that may end up on a command line.
fn has_shell_metachar(s: &str) -> bool {
    s.chars().any(|c| {
        c.is_control()
            || matches!(
                c,
                ';' | '&'
                    | '|'
                    | '$'
                    | '`'
                    | '\''
                    | '"'
                    | '\\'
                    | '<'
                    | '>'
                    | '('
                    | ')'
                    | '{'
                    | '}'
                    | '*'
                    | '?'
            )
    })
}
29
30pub fn validate_request(req: &ProvisionRequest) -> Result<(), String> {
32 if req.peer_name.is_empty() || req.peer_name.len() > 128 {
34 return Err("peer_name must be 1-128 characters".into());
35 }
36 if !req
37 .peer_name
38 .chars()
39 .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
40 {
41 return Err("peer_name must contain only alphanumeric, dash, underscore, or dot".into());
42 }
43
44 if req.ssh_target.is_empty() || req.ssh_target.len() > 256 {
46 return Err("ssh_target must be 1-256 characters".into());
47 }
48 if has_shell_metachar(&req.ssh_target) {
49 return Err("ssh_target contains invalid characters".into());
50 }
51
52 if req.remote_base.is_empty() || req.remote_base.len() > 512 {
54 return Err("remote_base must be 1-512 characters".into());
55 }
56 if req.remote_base.contains("..") {
57 return Err("remote_base must not contain path traversal".into());
58 }
59 if has_shell_metachar(&req.remote_base) {
60 return Err("remote_base contains invalid characters".into());
61 }
62
63 Ok(())
64}
65
66pub struct ProvisionState {
67 pub pool: ConnPool,
68}
69
70pub fn provision_routes(state: Arc<ProvisionState>) -> Router {
71 Router::new()
72 .route("/api/provision/peer", post(handle_provision))
73 .route("/api/provision/runs", get(handle_list_runs))
74 .route("/api/provision/run/{id}", get(handle_get_run))
75 .with_state(state)
76}
77
78async fn handle_provision(
79 State(s): State<Arc<ProvisionState>>,
80 Json(req): Json<ProvisionRequest>,
81) -> Json<Value> {
82 if let Err(e) = validate_request(&req) {
83 return Json(json!({"error": e}));
84 }
85 let pool = s.pool.clone();
86 let peer = req.peer_name.clone();
87 let handle = tokio::spawn(async move {
88 match provision_peer(&pool, &req).await {
89 Ok(run_id) => tracing::info!(run_id, peer = %peer, "provisioning complete"),
90 Err(e) => tracing::warn!(peer = %peer, error = %e, "provisioning failed"),
91 }
92 });
93 tokio::spawn(async move {
95 if let Err(e) = handle.await {
96 tracing::error!(error = %e, "provisioning task panicked");
97 }
98 });
99 Json(json!({"ok": true, "message": "provisioning started"}))
100}
101
102#[derive(Deserialize, Default)]
103struct ListQuery {
104 limit: Option<u32>,
105}
106
107async fn handle_list_runs(
108 State(s): State<Arc<ProvisionState>>,
109 Query(q): Query<ListQuery>,
110) -> Json<Value> {
111 let conn = match s.pool.get() {
112 Ok(c) => c,
113 Err(e) => return Json(json!({"error": e.to_string()})),
114 };
115 let limit = q.limit.unwrap_or(20).min(100);
116 let mut stmt = match conn.prepare(
117 "SELECT id, peer_name, ssh_target, status, items_total, items_done, \
118 started_at, completed_at, error_message \
119 FROM provision_runs ORDER BY id DESC LIMIT ?1",
120 ) {
121 Ok(s) => s,
122 Err(e) => return Json(json!({"error": e.to_string()})),
123 };
124 let rows: Vec<Value> = stmt
125 .query_map([limit], |r| {
126 Ok(json!({
127 "id": r.get::<_, i64>(0)?,
128 "peer_name": r.get::<_, String>(1)?,
129 "ssh_target": r.get::<_, String>(2)?,
130 "status": r.get::<_, String>(3)?,
131 "items_total": r.get::<_, u32>(4)?,
132 "items_done": r.get::<_, u32>(5)?,
133 "started_at": r.get::<_, String>(6)?,
134 "completed_at": r.get::<_, Option<String>>(7)?,
135 "error": r.get::<_, Option<String>>(8)?,
136 }))
137 })
138 .map(|rows| rows.flatten().collect())
139 .unwrap_or_default();
140 Json(json!({"runs": rows}))
141}
142
143async fn handle_get_run(State(s): State<Arc<ProvisionState>>, Path(id): Path<i64>) -> Json<Value> {
144 let conn = match s.pool.get() {
145 Ok(c) => c,
146 Err(e) => return Json(json!({"error": e.to_string()})),
147 };
148 let run = conn.query_row(
149 "SELECT id, peer_name, ssh_target, status, items_total, items_done, \
150 started_at, completed_at, error_message FROM provision_runs WHERE id = ?1",
151 [id],
152 |r| {
153 Ok(json!({
154 "id": r.get::<_, i64>(0)?,
155 "peer_name": r.get::<_, String>(1)?,
156 "ssh_target": r.get::<_, String>(2)?,
157 "status": r.get::<_, String>(3)?,
158 "items_total": r.get::<_, u32>(4)?,
159 "items_done": r.get::<_, u32>(5)?,
160 "started_at": r.get::<_, String>(6)?,
161 "completed_at": r.get::<_, Option<String>>(7)?,
162 "error": r.get::<_, Option<String>>(8)?,
163 }))
164 },
165 );
166 let items = list_items(&conn, id);
167 match run {
168 Ok(r) => Json(json!({"run": r, "items": items})),
169 Err(e) => Json(json!({"error": e.to_string()})),
170 }
171}
172
173fn list_items(conn: &rusqlite::Connection, run_id: i64) -> Vec<Value> {
174 let mut stmt = match conn.prepare(
175 "SELECT id, item_type, source_path, dest_path, status, \
176 bytes_transferred, duration_ms, error_message \
177 FROM provision_items WHERE run_id = ?1 ORDER BY id",
178 ) {
179 Ok(s) => s,
180 Err(_) => return vec![],
181 };
182 stmt.query_map([run_id], |r| {
183 Ok(json!({
184 "id": r.get::<_, i64>(0)?,
185 "item_type": r.get::<_, String>(1)?,
186 "source": r.get::<_, String>(2)?,
187 "dest": r.get::<_, String>(3)?,
188 "status": r.get::<_, String>(4)?,
189 "bytes": r.get::<_, u64>(5)?,
190 "duration_ms": r.get::<_, u64>(6)?,
191 "error": r.get::<_, Option<String>>(7)?,
192 }))
193 })
194 .map(|rows| rows.flatten().collect())
195 .unwrap_or_default()
196}