arbor_server/
server.rs

1//! WebSocket server implementation.
2//!
3//! Handles client connections and routes messages to handlers.
4
5use crate::handlers::{
6    handle_context, handle_discover, handle_impact, handle_info, handle_node_get, handle_search,
7    SharedGraph,
8};
9use crate::protocol::{
10    ContextParams, DiscoverParams, ImpactParams, NodeGetParams, Request, Response, SearchParams,
11};
12use arbor_graph::ArborGraph;
13use futures_util::{SinkExt, StreamExt};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::RwLock;
18use tokio_tungstenite::{accept_async, tungstenite::Message};
19use tracing::{debug, error, info, warn};
20
/// Server configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Address to bind to.
    pub addr: SocketAddr,
}

impl Default for ServerConfig {
    /// Returns a configuration bound to the IPv4 loopback address on port 7432.
    fn default() -> Self {
        Self {
            // Built from parts rather than parsed from a string literal: the
            // construction is infallible, so no `unwrap` panic path exists.
            addr: SocketAddr::from(([127, 0, 0, 1], 7432)),
        }
    }
}
34
35/// The Arbor WebSocket server.
36pub struct ArborServer {
37    config: ServerConfig,
38    graph: SharedGraph,
39}
40
41impl ArborServer {
42    /// Creates a new server with the given graph.
43    pub fn new(graph: ArborGraph, config: ServerConfig) -> Self {
44        Self {
45            config,
46            graph: Arc::new(RwLock::new(graph)),
47        }
48    }
49
50    /// Returns a handle to the shared graph for updates.
51    pub fn graph(&self) -> SharedGraph {
52        self.graph.clone()
53    }
54
55    /// Runs the server, accepting connections forever.
56    pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
57        let listener = TcpListener::bind(&self.config.addr).await?;
58        info!("Arbor server listening on {}", self.config.addr);
59
60        loop {
61            match listener.accept().await {
62                Ok((stream, addr)) => {
63                    debug!("New connection from {}", addr);
64                    let graph = self.graph.clone();
65                    tokio::spawn(async move {
66                        if let Err(e) = handle_connection(stream, addr, graph).await {
67                            error!("Connection error from {}: {}", addr, e);
68                        }
69                    });
70                }
71                Err(e) => {
72                    error!("Accept error: {}", e);
73                }
74            }
75        }
76    }
77}
78
79/// Handles a single WebSocket connection.
80async fn handle_connection(
81    stream: TcpStream,
82    addr: SocketAddr,
83    graph: SharedGraph,
84) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85    let ws_stream = accept_async(stream).await?;
86    info!("WebSocket connection established with {}", addr);
87
88    let (mut write, mut read) = ws_stream.split();
89
90    while let Some(msg) = read.next().await {
91        let msg = match msg {
92            Ok(m) => m,
93            Err(e) => {
94                warn!("Message error from {}: {}", addr, e);
95                break;
96            }
97        };
98
99        if msg.is_close() {
100            debug!("Client {} disconnected", addr);
101            break;
102        }
103
104        if msg.is_ping() {
105            write.send(Message::Pong(msg.into_data())).await?;
106            continue;
107        }
108
109        if msg.is_text() {
110            let text = msg.to_text().unwrap_or("");
111            let response = process_message(text, graph.clone()).await;
112            let json = serde_json::to_string(&response)?;
113            write.send(Message::Text(json)).await?;
114        }
115    }
116
117    info!("Connection closed: {}", addr);
118    Ok(())
119}
120
121/// Processes a JSON-RPC message and returns a response.
122async fn process_message(text: &str, graph: SharedGraph) -> Response {
123    // Parse the request
124    let request: Request = match serde_json::from_str(text) {
125        Ok(r) => r,
126        Err(_) => return Response::parse_error(),
127    };
128
129    let id = request.id.clone();
130    let method = request.method.as_str();
131
132    debug!("Processing method: {}", method);
133
134    // Route to handler
135    match method {
136        "graph.info" => handle_info(graph, id).await,
137
138        "discover" => match serde_json::from_value::<DiscoverParams>(request.params) {
139            Ok(params) => handle_discover(graph, id, params).await,
140            Err(e) => Response::invalid_params(id, e.to_string()),
141        },
142
143        "impact" => match serde_json::from_value::<ImpactParams>(request.params) {
144            Ok(params) => handle_impact(graph, id, params).await,
145            Err(e) => Response::invalid_params(id, e.to_string()),
146        },
147
148        "context" => match serde_json::from_value::<ContextParams>(request.params) {
149            Ok(params) => handle_context(graph, id, params).await,
150            Err(e) => Response::invalid_params(id, e.to_string()),
151        },
152
153        "search" => match serde_json::from_value::<SearchParams>(request.params) {
154            Ok(params) => handle_search(graph, id, params).await,
155            Err(e) => Response::invalid_params(id, e.to_string()),
156        },
157
158        "node.get" => match serde_json::from_value::<NodeGetParams>(request.params) {
159            Ok(params) => handle_node_get(graph, id, params).await,
160            Err(e) => Response::invalid_params(id, e.to_string()),
161        },
162
163        _ => Response::method_not_found(id, method),
164    }
165}