1use anyhow::{bail, Context, Result};
4use std::cell::RefCell;
5use std::collections::HashMap;
6use std::io::{Read, Write};
7use std::os::unix::io::FromRawFd;
8use std::os::unix::net::{UnixListener, UnixStream};
9use std::path::Path;
10use std::sync::Arc;
11use std::time::SystemTime;
12
13pub use cell_macros::{call_as, service_schema};
14pub use rkyv;
15
thread_local! {
    /// Per-thread cache of open client streams, keyed by the socket path they are connected to.
    ///
    /// `invoke_rpc` takes a stream out of this map for a call and puts it back only after a
    /// successful request/response exchange, so broken connections are never reused.
    static CONNECTION_POOL: RefCell<HashMap<String, UnixStream>> = RefCell::default();
}
21
22pub fn invoke_rpc(_service_name: &str, socket_path: &str, payload: &[u8]) -> Result<Vec<u8>> {
23 let response = CONNECTION_POOL.with(|pool_cell| {
25 let mut pool = pool_cell.borrow_mut();
26 if let Some(mut stream) = pool.remove(socket_path) {
27 if send_request(&mut stream, payload).is_ok() {
28 if let Ok(resp) = read_response(&mut stream) {
29 pool.insert(socket_path.to_string(), stream);
30 return Some(Ok(resp));
31 }
32 }
33 }
34 None
35 });
36 if let Some(res) = response {
37 return res;
38 }
39 let mut stream = connect_new(socket_path)?;
40 send_request(&mut stream, payload)?;
41 let resp = read_response(&mut stream)?;
42 CONNECTION_POOL.with(|pool_cell| {
43 pool_cell
44 .borrow_mut()
45 .insert(socket_path.to_string(), stream);
46 });
47 Ok(resp)
48}
49
/// Client bound to one service connection that can optionally batch outgoing requests.
///
/// Requests are framed into an in-memory buffer and written to the socket once `batch_limit`
/// of them are pending (or when the buffer is flushed explicitly).
pub struct CellClient {
    /// Connected socket used for both writing requests and reading responses.
    stream: UnixStream,
    /// Number of pending requests at which the buffer is written automatically.
    batch_limit: usize,
    /// Requests framed into `wbuf` but not yet written to `stream`.
    pending_count: usize,
    /// Outgoing bytes: length-prefixed request frames awaiting a write.
    wbuf: Vec<u8>,
}
58
59impl CellClient {
60 pub fn connect(socket_path: &str) -> Result<Self> {
61 Self::connect_with_batch(socket_path, 1) }
63
64 pub fn connect_to_service(service_name: &str) -> Result<Self> {
65 Self::connect(&resolve_socket_path(service_name))
66 }
67
68 pub fn connect_with_batch(socket_path: &str, batch_size: usize) -> Result<Self> {
69 let stream = connect_new(socket_path)?;
70 Ok(Self {
71 stream,
72 wbuf: Vec::with_capacity(4096 * batch_size), batch_limit: batch_size,
74 pending_count: 0,
75 })
76 }
77
78 pub fn call(&mut self, payload: &[u8]) -> Result<Vec<u8>> {
79 self.wbuf
81 .extend_from_slice(&(payload.len() as u32).to_be_bytes());
82 self.wbuf.extend_from_slice(payload);
83 self.pending_count += 1;
84
85 if self.pending_count >= self.batch_limit {
87 self.flush_writes()?;
88 }
89
90 if self.batch_limit == 1 {
139 self.stream.write_all(&self.wbuf)?;
141 self.wbuf.clear();
142 self.pending_count = 0;
143 read_response(&mut self.stream).map_err(|e| e.into())
144 } else {
145 Ok(Vec::new())
149 }
150 }
151
152 pub fn queue_request(&mut self, payload: &[u8]) -> Result<bool> {
155 self.wbuf
156 .extend_from_slice(&(payload.len() as u32).to_be_bytes());
157 self.wbuf.extend_from_slice(payload);
158 self.pending_count += 1;
159
160 if self.pending_count >= self.batch_limit {
161 self.stream.write_all(&self.wbuf)?;
162 self.wbuf.clear();
163 self.pending_count = 0;
164 return Ok(true); }
166 Ok(false)
167 }
168
169 pub fn read_n_responses(&mut self, n: usize) -> Result<()> {
171 for _ in 0..n {
172 let _ = read_response(&mut self.stream)?;
173 }
174 Ok(())
175 }
176
177 pub fn flush_writes(&mut self) -> Result<()> {
178 if !self.wbuf.is_empty() {
179 self.stream.write_all(&self.wbuf)?;
180 self.wbuf.clear();
181 self.pending_count = 0;
182 }
183 Ok(())
184 }
185}
186
/// Returns the socket path for the dependency `service_name`.
///
/// The environment variable `CELL_DEP_<SERVICE_NAME>_SOCK` (with the name upper-cased) takes
/// precedence. When it is unset or unreadable, the conventional sibling location
/// `../<service_name>/run/cell.sock` is used.
pub fn resolve_socket_path(service_name: &str) -> String {
    let mut override_var = String::from("CELL_DEP_");
    override_var.push_str(&service_name.to_uppercase());
    override_var.push_str("_SOCK");

    match std::env::var(&override_var) {
        Ok(configured) => configured,
        Err(_) => format!("../{}/run/cell.sock", service_name),
    }
}
192
193fn connect_new(path: &str) -> Result<UnixStream> {
194 let stream = UnixStream::connect(path).with_context(|| format!("Connect to {}", path))?;
195 stream
196 .set_nonblocking(false)
197 .context("Failed to set blocking mode")?;
198 stream.set_read_timeout(Some(std::time::Duration::from_secs(60)))?;
199 Ok(stream)
200}
201
/// Writes one frame to `stream`: a 4-byte big-endian length followed by `payload`.
///
/// # Errors
/// Returns `InvalidInput`, without writing anything, when `payload` is longer than `u32::MAX`
/// bytes. The length prefix could not represent it, and truncating it would desynchronise the
/// stream. Otherwise any I/O error from the writes is propagated.
fn send_request(stream: &mut UnixStream, payload: &[u8]) -> std::io::Result<()> {
    if payload.len() > u32::MAX as usize {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "request payload exceeds u32::MAX bytes",
        ));
    }
    stream.write_all(&(payload.len() as u32).to_be_bytes())?;
    stream.write_all(payload)?;
    stream.flush()
}
207
/// Reads one frame from `stream` (4-byte big-endian length, then that many bytes) and returns
/// its body.
///
/// The length comes from the peer. The body is therefore read incrementally, with at most
/// 1 MiB reserved up front, instead of allocating the full advertised size (up to 4 GiB)
/// before any data has arrived.
///
/// # Errors
/// Returns `UnexpectedEof` if the connection ends before the header or the full body has been
/// read. Any other I/O error, including a read timeout, is propagated.
fn read_response(stream: &mut UnixStream) -> std::io::Result<Vec<u8>> {
    const MAX_UPFRONT_ALLOC: usize = 1 << 20;

    let mut len_buf = [0u8; 4];
    stream.read_exact(&mut len_buf)?;
    let len = u32::from_be_bytes(len_buf);

    let mut buf = Vec::with_capacity((len as usize).min(MAX_UPFRONT_ALLOC));
    // `take` stops exactly at the frame boundary, so bytes of the next frame stay unread.
    Read::take(&mut *stream, u64::from(len)).read_to_end(&mut buf)?;
    if buf.len() as u64 != u64::from(len) {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "connection closed before the full response body arrived",
        ));
    }
    Ok(buf)
}
216
217pub fn run_service_with_schema<F>(service_name: &str, schema_json: &str, handler: F) -> Result<()>
219where
220 F: Fn(&[u8]) -> Result<Vec<u8>> + Send + Sync + 'static,
221{
222 let listener = if let Ok(fd_str) = std::env::var("CELL_SOCKET_FD") {
223 let fd: i32 = fd_str.parse().context("CELL_SOCKET_FD invalid")?;
224 unsafe { UnixListener::from_raw_fd(fd) }
225 } else {
226 let path_str =
227 std::env::var("CELL_SOCKET_PATH").unwrap_or_else(|_| "run/cell.sock".to_string());
228 let path = Path::new(&path_str);
229 if let Some(p) = path.parent() {
230 std::fs::create_dir_all(p)?;
231 }
232 if path.exists() {
233 std::fs::remove_file(path)?;
234 }
235 UnixListener::bind(path)?
236 };
237
238 listener
239 .set_nonblocking(false)
240 .context("Set listener blocking failed")?;
241 eprintln!(
242 "{} 🚀 Service '{}' ready",
243 humantime::format_rfc3339(SystemTime::now()),
244 service_name
245 );
246 let handler_arc = Arc::new(handler);
247 let schema_bytes = schema_json.as_bytes().to_vec();
248
249 for stream in listener.incoming() {
250 match stream {
251 Ok(mut s) => {
252 let _ = s.set_nonblocking(false);
253 let h = handler_arc.clone();
254 let schema = schema_bytes.clone();
255 std::thread::spawn(move || {
256 if let Err(e) = handle_client_loop(&mut s, &schema, &*h) {
257 if e.to_string() != "Client disconnected" {
258 eprintln!("Handler error: {}", e);
259 }
260 }
261 });
262 }
263 Err(e) => eprintln!("Accept error: {}", e),
264 }
265 }
266 Ok(())
267}
268
269fn handle_client_loop(
270 stream: &mut UnixStream,
271 schema_bytes: &[u8],
272 handler: &dyn Fn(&[u8]) -> Result<Vec<u8>>,
273) -> Result<()> {
274 loop {
280 let mut len_buf = [0u8; 4];
281 match stream.read_exact(&mut len_buf) {
282 Ok(_) => {}
283 Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
284 return Err(anyhow::anyhow!("Client disconnected"))
285 }
286 Err(e) => return Err(e.into()),
287 }
288 let len = u32::from_be_bytes(len_buf) as usize;
289 if len > 256 * 1024 * 1024 {
290 bail!("Message too large");
291 }
292 let mut msg_buf = vec![0u8; len];
293 stream.read_exact(&mut msg_buf)?;
294
295 if &msg_buf == b"__SCHEMA__" {
296 stream.write_all(&(schema_bytes.len() as u32).to_be_bytes())?;
297 stream.write_all(schema_bytes)?;
298 stream.flush()?;
299 continue;
300 }
301
302 let response_bytes = handler(&msg_buf)?;
303 stream.write_all(&(response_bytes.len() as u32).to_be_bytes())?;
304 stream.write_all(&response_bytes)?;
305 stream.flush()?;
306 }
307}