Skip to main content

st/web_dashboard/
server.rs

1//! Axum HTTP server for the web dashboard
2
3use super::{api, assets, collab, state_sync, voice, websocket, DashboardState, SharedState};
4use crate::in_memory_logger::InMemoryLogStore;
5use anyhow::Result;
6use axum::{
7    body::Body,
8    extract::ConnectInfo,
9    http::{header, Request, StatusCode},
10    middleware::Next,
11    response::{Html, IntoResponse, Response},
12    routing::{get, post},
13    Router,
14};
15use tower_http::cors::CorsLayer;
16use ipnet::IpNet;
17use std::net::{IpAddr, SocketAddr};
18use std::path::PathBuf;
19use std::sync::Arc;
20use tokio::sync::RwLock;
21
22/// Allowed networks for connection filtering
23#[derive(Clone)]
24struct AllowedNetworks {
25    networks: Vec<IpNet>,
26    allow_all: bool,
27}
28
29impl AllowedNetworks {
30    fn new(cidrs: Vec<String>) -> Self {
31        if cidrs.is_empty() {
32            // Default: localhost only
33            return Self {
34                networks: vec!["127.0.0.0/8".parse().unwrap(), "::1/128".parse().unwrap()],
35                allow_all: false,
36            };
37        }
38
39        let mut networks = Vec::new();
40        let mut allow_all = false;
41
42        for cidr in cidrs {
43            if cidr == "0.0.0.0/0" || cidr == "::/0" || cidr == "any" {
44                allow_all = true;
45                break;
46            }
47            if let Ok(net) = cidr.parse::<IpNet>() {
48                networks.push(net);
49            } else {
50                eprintln!("Warning: Invalid CIDR '{}', ignoring", cidr);
51            }
52        }
53
54        Self {
55            networks,
56            allow_all,
57        }
58    }
59
60    fn is_allowed(&self, ip: IpAddr) -> bool {
61        if self.allow_all {
62            return true;
63        }
64        // Always allow localhost
65        if ip.is_loopback() {
66            return true;
67        }
68        self.networks.iter().any(|net| net.contains(&ip))
69    }
70}
71
72/// Middleware to check if client IP is allowed
73async fn check_allowed_network(
74    ConnectInfo(addr): ConnectInfo<SocketAddr>,
75    req: Request<Body>,
76    next: Next,
77) -> Result<Response, StatusCode> {
78    let allowed = req
79        .extensions()
80        .get::<AllowedNetworks>()
81        .map(|nets| nets.is_allowed(addr.ip()))
82        .unwrap_or(true);
83
84    if allowed {
85        Ok(next.run(req).await)
86    } else {
87        eprintln!("Rejected connection from {}", addr.ip());
88        Err(StatusCode::FORBIDDEN)
89    }
90}
91
92/// Start the web dashboard server
93pub async fn start_server(
94    port: u16,
95    open_browser: bool,
96    allow_networks: Vec<String>,
97    log_store: InMemoryLogStore,
98) -> Result<()> {
99    let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
100    let state: SharedState = Arc::new(RwLock::new(DashboardState::new(cwd, log_store)));
101
102    let has_explicit_networks = !allow_networks.is_empty();
103    let allowed = AllowedNetworks::new(allow_networks.clone());
104    let bind_all = has_explicit_networks || allowed.allow_all;
105
106    let app = Router::new()
107        // Static assets
108        .route("/", get(serve_index))
109        .route("/style.css", get(serve_css))
110        .route("/app.js", get(serve_js))
111        .route("/xterm.min.js", get(serve_xterm_js))
112        .route("/xterm.css", get(serve_xterm_css))
113        .route("/xterm-addon-fit.min.js", get(serve_xterm_fit_js))
114        .route("/marked.min.js", get(serve_marked_js))
115        // API endpoints
116        .route("/api/health", get(api::health))
117        .route("/api/files", get(api::list_files))
118        .route("/api/file", get(api::read_file))
119        .route("/api/file", post(api::write_file))
120        .route("/api/tree", get(api::get_tree))
121        .route("/api/markdown", get(api::render_markdown))
122        .route("/api/logs", get(api::get_logs))
123        // Config endpoints
124        .route(
125            "/api/config/layout",
126            get(api::get_layout_config).post(api::save_layout_config),
127        )
128        .route(
129            "/api/config/theme",
130            get(api::get_theme_config).post(api::save_theme_config),
131        )
132        // Prompt endpoints
133        .route("/api/prompt", get(api::get_active_prompts).post(api::ask_prompt))
134        .route("/api/prompt/:prompt_id/answer", post(api::answer_prompt))
135        // WebSocket endpoints
136        .route("/ws/terminal", get(websocket::terminal_handler))
137        .route("/ws/state", get(state_sync::state_handler))
138        .route("/ws/collab", get(collab::collab_handler))
139        // Voice API endpoints (stub handlers if feature disabled)
140        .route("/api/voice/transcribe", post(voice::transcribe))
141        .route("/api/voice/register", post(voice::register_speaker))
142        .route("/api/voice/speak", post(voice::speak))
143        .layer(axum::Extension(allowed.clone()))
144        .layer(CorsLayer::permissive())
145        .with_state(state);
146
147    // Bind to all interfaces if networks specified, localhost otherwise
148    let bind_addr: IpAddr = if bind_all {
149        [0, 0, 0, 0].into()
150    } else {
151        [127, 0, 0, 1].into()
152    };
153    let addr = SocketAddr::from((bind_addr, port));
154
155    println!("\x1b[32m");
156    println!("  ╔══════════════════════════════════════════════════════╗");
157    println!("  ║        Smart Tree Web Dashboard                      ║");
158    println!("  ╠══════════════════════════════════════════════════════╣");
159    if bind_all {
160        println!(
161            "  ║  http://0.0.0.0:{}                                ║",
162            port
163        );
164        println!("  ║                                                      ║");
165        println!("  ║  Allowed networks:                                   ║");
166        if allowed.allow_all {
167            println!("  ║    ANY (0.0.0.0/0)                                   ║");
168        } else {
169            for net in &allowed.networks {
170                if !net.addr().is_loopback() {
171                    println!("  ║    {}                                       ║", net);
172                }
173            }
174        }
175    } else {
176        println!(
177            "  ║  http://127.0.0.1:{}                              ║",
178            port
179        );
180        println!("  ║                                                      ║");
181        println!("  ║  Localhost only (use --allow for network access)     ║");
182    }
183    println!("  ║                                                      ║");
184    println!("  ║  Terminal: Real PTY with bash/zsh                    ║");
185    println!("  ║  Files: Browse and edit                              ║");
186    println!("  ║  Preview: Markdown rendering                         ║");
187    println!("  ╚══════════════════════════════════════════════════════╝");
188    println!("\x1b[0m");
189
190    if open_browser {
191        let url = format!("http://127.0.0.1:{}", port);
192        if let Err(e) = open::that(&url) {
193            eprintln!("Failed to open browser: {}", e);
194        }
195    }
196
197    let listener = tokio::net::TcpListener::bind(addr).await?;
198    axum::serve(
199        listener,
200        app.into_make_service_with_connect_info::<SocketAddr>(),
201    )
202    .await?;
203
204    Ok(())
205}
206
207// Static asset handlers
208async fn serve_index() -> Html<&'static str> {
209    Html(assets::INDEX_HTML)
210}
211
212async fn serve_css() -> impl IntoResponse {
213    ([(header::CONTENT_TYPE, "text/css")], assets::STYLE_CSS)
214}
215
216async fn serve_js() -> impl IntoResponse {
217    (
218        [(header::CONTENT_TYPE, "application/javascript")],
219        assets::APP_JS,
220    )
221}
222
223async fn serve_xterm_js() -> impl IntoResponse {
224    (
225        [(header::CONTENT_TYPE, "application/javascript")],
226        assets::XTERM_JS,
227    )
228}
229
230async fn serve_xterm_css() -> impl IntoResponse {
231    ([(header::CONTENT_TYPE, "text/css")], assets::XTERM_CSS)
232}
233
234async fn serve_xterm_fit_js() -> impl IntoResponse {
235    (
236        [(header::CONTENT_TYPE, "application/javascript")],
237        assets::XTERM_FIT_JS,
238    )
239}
240
241async fn serve_marked_js() -> impl IntoResponse {
242    (
243        [(header::CONTENT_TYPE, "application/javascript")],
244        assets::MARKED_JS,
245    )
246}