// infinite_db/infinitedb_server/runtime.rs
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Semaphore;

use crate::infinitedb_core::branch::BranchId;
use crate::infinitedb_core::snapshot::SnapshotId;
use crate::InfiniteDb;

use super::api::{handle_request, Request, Response};
use super::session::{AccessLevel, Session, SpaceGrant};
17#[derive(Debug, Clone)]
19pub struct ServerConfig {
20 pub max_connections: usize,
21 pub default_branch: BranchId,
22}
23
24impl Default for ServerConfig {
25 fn default() -> Self {
26 Self {
27 max_connections: 128,
28 default_branch: BranchId::MAIN,
29 }
30 }
31}
32
33pub struct Server {
35 listener: TcpListener,
36 db: Arc<InfiniteDb>,
37 config: ServerConfig,
38 grants: Vec<SpaceGrant>,
39 limiter: Arc<Semaphore>,
40}
41
42impl Server {
43 pub async fn bind(
45 addr: SocketAddr,
46 db: Arc<InfiniteDb>,
47 config: ServerConfig,
48 grants: Vec<SpaceGrant>,
49 ) -> io::Result<Self> {
50 let listener = TcpListener::bind(addr).await?;
51 let limiter = Arc::new(Semaphore::new(config.max_connections));
52 Ok(Self {
53 listener,
54 db,
55 config,
56 grants,
57 limiter,
58 })
59 }
60
61 pub fn local_addr(&self) -> io::Result<SocketAddr> {
62 self.listener.local_addr()
63 }
64
65 pub async fn run(self) -> io::Result<()> {
67 loop {
68 let (stream, _) = self.listener.accept().await?;
69 let permit = Arc::clone(&self.limiter)
70 .acquire_owned()
71 .await
72 .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;
73 let db = Arc::clone(&self.db);
74 let grants = self.grants.clone();
75 let branch = self.config.default_branch;
76 tokio::spawn(async move {
77 let _permit = permit;
78 let _ = serve_connection(stream, db, branch, grants).await;
79 });
80 }
81 }
82}
83
84async fn serve_connection(
85 mut stream: TcpStream,
86 db: Arc<InfiniteDb>,
87 branch: BranchId,
88 grants: Vec<SpaceGrant>,
89) -> io::Result<()> {
90 let pinned = db
91 .branch_head(branch)
92 .unwrap_or(SnapshotId(0));
93 let opened_at = db.revision();
94 let session = Session::open_at_revision(branch, pinned, opened_at, grants);
95
96 loop {
97 let request: Request = read_frame_async(&mut stream).await?;
98 let response = handle_request(&db, &session, request);
99 write_frame_async(&mut stream, &response).await?;
100 if matches!(response, Response::Error(_)) {
101 }
103 }
104}
105
106async fn read_frame_async<T: bincode::Decode<()> + Send + 'static>(
107 stream: &mut TcpStream,
108) -> io::Result<T> {
109 let mut len_buf = [0u8; 8];
110 stream.read_exact(&mut len_buf).await?;
111 let len = u64::from_le_bytes(len_buf) as usize;
112 if len > 64 * 1024 * 1024 {
113 return Err(io::Error::new(
114 io::ErrorKind::InvalidData,
115 "frame too large",
116 ));
117 }
118 let mut payload = vec![0u8; len];
119 stream.read_exact(&mut payload).await?;
120 let (msg, _) = bincode::decode_from_slice::<T, _>(&payload, bincode::config::standard())
121 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
122 Ok(msg)
123}
124
125async fn write_frame_async<T: bincode::Encode + Send + Sync>(
126 stream: &mut TcpStream,
127 msg: &T,
128) -> io::Result<()> {
129 let payload = bincode::encode_to_vec(msg, bincode::config::standard())
130 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
131 let len = payload.len() as u64;
132 stream.write_all(&len.to_le_bytes()).await?;
133 stream.write_all(&payload).await?;
134 stream.flush().await
135}
136
137pub async fn client_roundtrip(
139 addr: SocketAddr,
140 request: Request,
141) -> io::Result<Response> {
142 let mut stream = TcpStream::connect(addr).await?;
143 write_frame_async(&mut stream, &request).await?;
144 read_frame_async(&mut stream).await
145}
146
147pub fn admin_grants(space_ids: &[u64]) -> Vec<SpaceGrant> {
149 space_ids
150 .iter()
151 .map(|id| SpaceGrant {
152 space: crate::infinitedb_core::address::SpaceId(*id),
153 level: AccessLevel::Admin,
154 })
155 .collect()
156}