1use crate::handlers::{
6 handle_context, handle_discover, handle_impact, handle_info, handle_node_get, handle_search,
7 SharedGraph,
8};
9use crate::protocol::{
10 ContextParams, DiscoverParams, ImpactParams, NodeGetParams, Request, Response, SearchParams,
11};
12use arbor_graph::ArborGraph;
13use futures_util::{SinkExt, StreamExt};
14use std::net::SocketAddr;
15use std::sync::Arc;
16use tokio::net::{TcpListener, TcpStream};
17use tokio::sync::RwLock;
18use tokio_tungstenite::{accept_async, tungstenite::Message};
19use tracing::{debug, error, info, warn};
20
/// Runtime configuration for [`ArborServer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerConfig {
    /// Socket address the WebSocket listener binds to.
    pub addr: SocketAddr,
}

impl Default for ServerConfig {
    /// Listens on the IPv4 loopback interface, port 7432.
    fn default() -> Self {
        Self {
            // Built from parts instead of parsing a string literal, so there is no panic path.
            addr: SocketAddr::from(([127, 0, 0, 1], 7432)),
        }
    }
}
34
35pub struct ArborServer {
37 config: ServerConfig,
38 graph: SharedGraph,
39}
40
41impl ArborServer {
42 pub fn new(graph: ArborGraph, config: ServerConfig) -> Self {
44 Self {
45 config,
46 graph: Arc::new(RwLock::new(graph)),
47 }
48 }
49
50 pub fn graph(&self) -> SharedGraph {
52 self.graph.clone()
53 }
54
55 pub async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
57 let listener = TcpListener::bind(&self.config.addr).await?;
58 info!("Arbor server listening on {}", self.config.addr);
59
60 loop {
61 match listener.accept().await {
62 Ok((stream, addr)) => {
63 debug!("New connection from {}", addr);
64 let graph = self.graph.clone();
65 tokio::spawn(async move {
66 if let Err(e) = handle_connection(stream, addr, graph).await {
67 error!("Connection error from {}: {}", addr, e);
68 }
69 });
70 }
71 Err(e) => {
72 error!("Accept error: {}", e);
73 }
74 }
75 }
76 }
77}
78
79async fn handle_connection(
81 stream: TcpStream,
82 addr: SocketAddr,
83 graph: SharedGraph,
84) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
85 let ws_stream = accept_async(stream).await?;
86 info!("WebSocket connection established with {}", addr);
87
88 let (mut write, mut read) = ws_stream.split();
89
90 while let Some(msg) = read.next().await {
91 let msg = match msg {
92 Ok(m) => m,
93 Err(e) => {
94 warn!("Message error from {}: {}", addr, e);
95 break;
96 }
97 };
98
99 if msg.is_close() {
100 debug!("Client {} disconnected", addr);
101 break;
102 }
103
104 if msg.is_ping() {
105 write.send(Message::Pong(msg.into_data())).await?;
106 continue;
107 }
108
109 if msg.is_text() {
110 let text = msg.to_text().unwrap_or("");
111 let response = process_message(text, graph.clone()).await;
112 let json = serde_json::to_string(&response)?;
113 write.send(Message::Text(json)).await?;
114 }
115 }
116
117 info!("Connection closed: {}", addr);
118 Ok(())
119}
120
121async fn process_message(text: &str, graph: SharedGraph) -> Response {
123 let request: Request = match serde_json::from_str(text) {
125 Ok(r) => r,
126 Err(_) => return Response::parse_error(),
127 };
128
129 let id = request.id.clone();
130 let method = request.method.as_str();
131
132 debug!("Processing method: {}", method);
133
134 match method {
136 "graph.info" => handle_info(graph, id).await,
137
138 "discover" => match serde_json::from_value::<DiscoverParams>(request.params) {
139 Ok(params) => handle_discover(graph, id, params).await,
140 Err(e) => Response::invalid_params(id, e.to_string()),
141 },
142
143 "impact" => match serde_json::from_value::<ImpactParams>(request.params) {
144 Ok(params) => handle_impact(graph, id, params).await,
145 Err(e) => Response::invalid_params(id, e.to_string()),
146 },
147
148 "context" => match serde_json::from_value::<ContextParams>(request.params) {
149 Ok(params) => handle_context(graph, id, params).await,
150 Err(e) => Response::invalid_params(id, e.to_string()),
151 },
152
153 "search" => match serde_json::from_value::<SearchParams>(request.params) {
154 Ok(params) => handle_search(graph, id, params).await,
155 Err(e) => Response::invalid_params(id, e.to_string()),
156 },
157
158 "node.get" => match serde_json::from_value::<NodeGetParams>(request.params) {
159 Ok(params) => handle_node_get(graph, id, params).await,
160 Err(e) => Response::invalid_params(id, e.to_string()),
161 },
162
163 _ => Response::method_not_found(id, method),
164 }
165}