Skip to main content

construct/gateway/
canvas.rs

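//! Live Canvas gateway routes — REST + WebSocket for real-time canvas updates.
//!
//! - `GET    /api/canvas/:id`         — get current canvas content (JSON)
//! - `GET    /api/canvas/:id/history` — get the canvas frame history (JSON)
//! - `POST   /api/canvas/:id`         — push content programmatically
//! - `DELETE /api/canvas/:id`         — clear a canvas
//! - `GET    /api/canvas`             — list all active canvases
//! - `WS     /ws/canvas/:id`          — real-time canvas updates via WebSocket
//!
//! Auth: the GET routes always require auth. POST/DELETE skip auth for loopback peers only.
//! The WebSocket route requires a token whenever pairing is enabled.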
7
8use super::AppState;
9use super::api::require_auth;
10use axum::{
11    extract::{
12        ConnectInfo, Path, State, WebSocketUpgrade,
13        ws::{Message, WebSocket},
14    },
15    http::{HeaderMap, StatusCode, header},
16    response::{IntoResponse, Json},
17};
18use futures_util::{SinkExt, StreamExt};
19use serde::Deserialize;
20
21/// POST /api/canvas/:id request body.
22#[derive(Deserialize)]
23pub struct CanvasPostBody {
24    pub content_type: Option<String>,
25    pub content: String,
26}
27
28/// GET /api/canvas — list all active canvases.
29pub async fn handle_canvas_list(
30    State(state): State<AppState>,
31    headers: HeaderMap,
32) -> impl IntoResponse {
33    if let Err(e) = require_auth(&state, &headers) {
34        return e.into_response();
35    }
36
37    let ids = state.canvas_store.list();
38    Json(serde_json::json!({ "canvases": ids })).into_response()
39}
40
41/// GET /api/canvas/:id — get current canvas content.
42pub async fn handle_canvas_get(
43    State(state): State<AppState>,
44    headers: HeaderMap,
45    Path(id): Path<String>,
46) -> impl IntoResponse {
47    if let Err(e) = require_auth(&state, &headers) {
48        return e.into_response();
49    }
50
51    match state.canvas_store.snapshot(&id) {
52        Some(frame) => Json(serde_json::json!({
53            "canvas_id": id,
54            "frame": frame,
55        }))
56        .into_response(),
57        None => (
58            StatusCode::NOT_FOUND,
59            Json(serde_json::json!({ "error": format!("Canvas '{}' not found", id) })),
60        )
61            .into_response(),
62    }
63}
64
65/// GET /api/canvas/:id/history — get canvas frame history.
66pub async fn handle_canvas_history(
67    State(state): State<AppState>,
68    headers: HeaderMap,
69    Path(id): Path<String>,
70) -> impl IntoResponse {
71    if let Err(e) = require_auth(&state, &headers) {
72        return e.into_response();
73    }
74
75    let history = state.canvas_store.history(&id);
76    Json(serde_json::json!({
77        "canvas_id": id,
78        "frames": history,
79    }))
80    .into_response()
81}
82
83/// POST /api/canvas/:id — push content to a canvas.
84pub async fn handle_canvas_post(
85    State(state): State<AppState>,
86    ConnectInfo(addr): ConnectInfo<std::net::SocketAddr>,
87    headers: HeaderMap,
88    Path(id): Path<String>,
89    Json(body): Json<CanvasPostBody>,
90) -> impl IntoResponse {
91    // Allow localhost (operator) without auth; external callers must authenticate.
92    if !addr.ip().is_loopback() {
93        if let Err(e) = require_auth(&state, &headers) {
94            return e.into_response();
95        }
96    }
97
98    let content_type = body.content_type.as_deref().unwrap_or("html");
99
100    // Validate content_type against allowed set (prevent injecting "eval" frames via REST).
101    if !crate::tools::canvas::ALLOWED_CONTENT_TYPES.contains(&content_type) {
102        return (
103            StatusCode::BAD_REQUEST,
104            Json(serde_json::json!({
105                "error": format!(
106                    "Invalid content_type '{}'. Allowed: {:?}",
107                    content_type,
108                    crate::tools::canvas::ALLOWED_CONTENT_TYPES
109                )
110            })),
111        )
112            .into_response();
113    }
114
115    // Enforce content size limit (same as tool-side validation).
116    if body.content.len() > crate::tools::canvas::MAX_CONTENT_SIZE {
117        return (
118            StatusCode::PAYLOAD_TOO_LARGE,
119            Json(serde_json::json!({
120                "error": format!(
121                    "Content exceeds maximum size of {} bytes",
122                    crate::tools::canvas::MAX_CONTENT_SIZE
123                )
124            })),
125        )
126            .into_response();
127    }
128
129    match state.canvas_store.render(&id, content_type, &body.content) {
130        Some(frame) => (
131            StatusCode::CREATED,
132            Json(serde_json::json!({
133                "canvas_id": id,
134                "frame": frame,
135            })),
136        )
137            .into_response(),
138        None => (
139            StatusCode::TOO_MANY_REQUESTS,
140            Json(serde_json::json!({
141                "error": "Maximum canvas count reached. Clear unused canvases first."
142            })),
143        )
144            .into_response(),
145    }
146}
147
148/// DELETE /api/canvas/:id — clear a canvas.
149pub async fn handle_canvas_clear(
150    State(state): State<AppState>,
151    ConnectInfo(addr): ConnectInfo<std::net::SocketAddr>,
152    headers: HeaderMap,
153    Path(id): Path<String>,
154) -> impl IntoResponse {
155    if !addr.ip().is_loopback() {
156        if let Err(e) = require_auth(&state, &headers) {
157            return e.into_response();
158        }
159    }
160
161    state.canvas_store.clear(&id);
162    Json(serde_json::json!({
163        "canvas_id": id,
164        "status": "cleared",
165    }))
166    .into_response()
167}
168
169/// Query parameters for canvas WebSocket.
170#[derive(Deserialize)]
171pub struct CanvasWsQuery {
172    pub token: Option<String>,
173}
174
175/// WS /ws/canvas/:id — real-time canvas updates.
176pub async fn handle_ws_canvas(
177    State(state): State<AppState>,
178    Path(id): Path<String>,
179    axum::extract::Query(params): axum::extract::Query<CanvasWsQuery>,
180    headers: HeaderMap,
181    ws: WebSocketUpgrade,
182) -> impl IntoResponse {
183    // Auth check — same precedence as ws::handle_ws_chat:
184    // 1. Authorization header, 2. Sec-WebSocket-Protocol bearer, 3. ?token= query param
185    if state.pairing.require_pairing() {
186        let token = headers
187            .get(header::AUTHORIZATION)
188            .and_then(|v| v.to_str().ok())
189            .and_then(|auth| auth.strip_prefix("Bearer "))
190            .or_else(|| {
191                headers
192                    .get("sec-websocket-protocol")
193                    .and_then(|v| v.to_str().ok())
194                    .and_then(|protos| {
195                        protos
196                            .split(',')
197                            .map(|p| p.trim())
198                            .find_map(|p| p.strip_prefix("bearer."))
199                    })
200            })
201            .or(params.token.as_deref())
202            .unwrap_or("");
203
204        if !state.pairing.is_authenticated(token) {
205            return (
206                StatusCode::UNAUTHORIZED,
207                "Unauthorized — provide Authorization header or Sec-WebSocket-Protocol bearer",
208            )
209                .into_response();
210        }
211    }
212
213    // Echo Sec-WebSocket-Protocol if the client requests our sub-protocol.
214    // Without this, browsers reject the connection when protocols are sent.
215    let ws = if headers
216        .get("sec-websocket-protocol")
217        .and_then(|v| v.to_str().ok())
218        .map_or(false, |protos| {
219            protos.split(',').any(|p| p.trim() == "construct.v1")
220        }) {
221        ws.protocols(["construct.v1"])
222    } else {
223        ws
224    };
225
226    ws.on_upgrade(move |socket| handle_canvas_socket(socket, state, id))
227        .into_response()
228}
229
230async fn handle_canvas_socket(socket: WebSocket, state: AppState, canvas_id: String) {
231    let (mut sender, mut receiver) = socket.split();
232
233    // Subscribe to canvas updates
234    let mut rx = match state.canvas_store.subscribe(&canvas_id) {
235        Some(rx) => rx,
236        None => {
237            let msg = serde_json::json!({
238                "type": "error",
239                "error": "Maximum canvas count reached",
240            });
241            let _ = sender.send(Message::Text(msg.to_string().into())).await;
242            return;
243        }
244    };
245
246    // Send current state immediately if available
247    if let Some(frame) = state.canvas_store.snapshot(&canvas_id) {
248        let msg = serde_json::json!({
249            "type": "frame",
250            "canvas_id": canvas_id,
251            "frame": frame,
252        });
253        let _ = sender.send(Message::Text(msg.to_string().into())).await;
254    }
255
256    // Send a connected acknowledgement
257    let ack = serde_json::json!({
258        "type": "connected",
259        "canvas_id": canvas_id,
260    });
261    let _ = sender.send(Message::Text(ack.to_string().into())).await;
262
263    // Spawn a task that forwards broadcast updates to the WebSocket
264    let canvas_id_clone = canvas_id.clone();
265    let send_task = tokio::spawn(async move {
266        loop {
267            match rx.recv().await {
268                Ok(frame) => {
269                    let msg = serde_json::json!({
270                        "type": "frame",
271                        "canvas_id": canvas_id_clone,
272                        "frame": frame,
273                    });
274                    if sender
275                        .send(Message::Text(msg.to_string().into()))
276                        .await
277                        .is_err()
278                    {
279                        break;
280                    }
281                }
282                Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
283                    // Client fell behind — notify and continue rather than disconnecting.
284                    let msg = serde_json::json!({
285                        "type": "lagged",
286                        "canvas_id": canvas_id_clone,
287                        "missed_frames": n,
288                    });
289                    let _ = sender.send(Message::Text(msg.to_string().into())).await;
290                }
291                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
292            }
293        }
294    });
295
296    // Read loop: we mostly ignore incoming messages but handle close/ping
297    while let Some(msg) = receiver.next().await {
298        match msg {
299            Ok(Message::Close(_)) | Err(_) => break,
300            _ => {} // Ignore all other messages (pings are handled by axum)
301        }
302    }
303
304    // Abort the send task when the connection is closed
305    send_task.abort();
306}