1use serde::Serialize;
2use serde::de::DeserializeOwned;
3use std::future::Future;
4use tokio::io::BufReader;
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::UnixStream;
7
/// Default upper bound (1 MiB) on the length of one JSON line accepted by
/// [`read_json_line`]; the trailing newline counts toward this limit.
pub const DEFAULT_MAX_LINE_BYTES: usize = 1 << 20;
9
10pub async fn read_json_line_with_limit<R, T>(
11 reader: &mut R,
12 max_bytes: usize,
13) -> std::io::Result<Option<T>>
14where
15 R: AsyncBufRead + Unpin,
16 T: DeserializeOwned,
17{
18 let mut buf = Vec::new();
19 let n = reader.read_until(b'\n', &mut buf).await?;
20 if n == 0 {
21 return Ok(None);
22 }
23 if buf.len() > max_bytes {
24 return Err(std::io::Error::new(
25 std::io::ErrorKind::InvalidData,
26 format!(
27 "json line exceeds max length ({} > {})",
28 buf.len(),
29 max_bytes
30 ),
31 ));
32 }
33
34 let s = std::str::from_utf8(&buf)
35 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
36
37 serde_json::from_str::<T>(s)
38 .map(Some)
39 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
40}
41
42pub async fn read_json_line<R, T>(reader: &mut R) -> std::io::Result<Option<T>>
43where
44 R: AsyncBufRead + Unpin,
45 T: DeserializeOwned,
46{
47 read_json_line_with_limit(reader, DEFAULT_MAX_LINE_BYTES).await
48}
49
50pub async fn write_json_line<W, T>(writer: &mut W, value: &T) -> std::io::Result<()>
51where
52 W: AsyncWrite + Unpin,
53 T: Serialize,
54{
55 let json = serde_json::to_string(value)
56 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
57 writer.write_all(json.as_bytes()).await?;
58 writer.write_all(b"\n").await?;
59 Ok(())
60}
61
62pub async fn serve_jsonl_connection<Req, Resp, F, Fut, InvalidResp>(
63 stream: UnixStream,
64 handler: F,
65 invalid_response: InvalidResp,
66) -> std::io::Result<()>
67where
68 Req: DeserializeOwned,
69 Resp: Serialize,
70 F: Fn(Req) -> Fut,
71 Fut: Future<Output = Resp>,
72 InvalidResp: Fn(std::io::Error) -> Resp,
73{
74 let (reader, mut writer) = stream.into_split();
75 let mut reader = BufReader::new(reader);
76
77 loop {
78 let Some(req) = (match read_json_line::<_, Req>(&mut reader).await {
79 Ok(v) => v,
80 Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
81 let resp = invalid_response(e);
82 let _ = write_json_line(&mut writer, &resp).await;
83 continue;
84 }
85 Err(e) => return Err(e),
86 }) else {
87 break;
88 };
89
90 let resp = handler(req).await;
91 write_json_line(&mut writer, &resp).await?;
92 }
93
94 Ok(())
95}
96
#[cfg(test)]
mod tests {
    // `BufReader`, `UnixStream` and the tokio extension traits come from the
    // parent module's imports via this glob.
    use super::*;

    #[tokio::test]
    async fn roundtrips_struct_over_jsonl() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Msg {
            kind: String,
            n: u64,
        }

        let (a, b) = tokio::io::duplex(1024);
        let (mut ar, mut aw) = tokio::io::split(a);
        let (mut br, mut bw) = tokio::io::split(b);
        let mut ar = BufReader::new(&mut ar);
        let mut br = BufReader::new(&mut br);

        let a_send = Msg {
            kind: "hello".to_string(),
            n: 42,
        };
        write_json_line(&mut aw, &a_send).await.unwrap();
        let b_recv: Msg = read_json_line(&mut br).await.unwrap().unwrap();
        assert_eq!(b_recv, a_send);

        let b_send = Msg {
            kind: "world".to_string(),
            n: 7,
        };
        write_json_line(&mut bw, &b_send).await.unwrap();
        let a_recv: Msg = read_json_line(&mut ar).await.unwrap().unwrap();
        assert_eq!(a_recv, b_send);
    }

    #[tokio::test]
    async fn returns_invalid_data_on_bad_json() {
        let (a, b) = tokio::io::duplex(1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        aw.write_all(b"{not json}\n").await.unwrap();

        let err = read_json_line::<_, serde_json::Value>(&mut br)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn errors_when_line_exceeds_limit() {
        let (a, b) = tokio::io::duplex(1024 * 1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        let big = "a".repeat(33);
        aw.write_all(big.as_bytes()).await.unwrap();
        aw.write_all(b"\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 32)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn limit_counts_the_trailing_newline() {
        let (a, b) = tokio::io::duplex(1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        // Each line is 4 bytes including its '\n'.
        aw.write_all(b"[1]\n[2]\n").await.unwrap();

        let ok = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 4)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(ok, serde_json::json!([1]));

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 3)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn recovers_after_oversized_line() {
        let (a, b) = tokio::io::duplex(1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        aw.write_all("a".repeat(64).as_bytes()).await.unwrap();
        aw.write_all(b"\n{\"n\":1}\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 16)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        // The oversized line must be fully skipped so the next one parses.
        let next = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 16)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(next, serde_json::json!({ "n": 1 }));
    }

    #[tokio::test]
    async fn serve_jsonl_connection_handles_invalid_and_valid_requests() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Req {
            n: u64,
        }
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Resp {
            ok: bool,
            n: u64,
        }

        let (a, b) = UnixStream::pair().unwrap();
        let h = tokio::spawn(async move {
            serve_jsonl_connection(
                a,
                |req: Req| async move { Resp { ok: true, n: req.n } },
                |_e| Resp { ok: false, n: 0 },
            )
            .await
            .unwrap();
        });

        let (r, mut w) = b.into_split();
        let mut r = BufReader::new(r);

        w.write_all(b"{not json}\n").await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: false, n: 0 });

        write_json_line(&mut w, &Req { n: 7 }).await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: true, n: 7 });

        drop(w);
        h.await.unwrap();
    }
}