Skip to main content

tako_socket/
lib.rs

1use serde::Serialize;
2use serde::de::DeserializeOwned;
3use std::future::Future;
4use tokio::io::BufReader;
5use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::UnixStream;
7
/// Size cap (1 MiB) applied by `read_json_line` to a single JSON line.
///
/// The count includes the line's trailing `'\n'`.
pub const DEFAULT_MAX_LINE_BYTES: usize = 1 << 20;
9
10pub async fn read_json_line_with_limit<R, T>(
11    reader: &mut R,
12    max_bytes: usize,
13) -> std::io::Result<Option<T>>
14where
15    R: AsyncBufRead + Unpin,
16    T: DeserializeOwned,
17{
18    let mut buf = Vec::new();
19    let n = reader.read_until(b'\n', &mut buf).await?;
20    if n == 0 {
21        return Ok(None);
22    }
23    if buf.len() > max_bytes {
24        return Err(std::io::Error::new(
25            std::io::ErrorKind::InvalidData,
26            format!(
27                "json line exceeds max length ({} > {})",
28                buf.len(),
29                max_bytes
30            ),
31        ));
32    }
33
34    let s = std::str::from_utf8(&buf)
35        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
36
37    serde_json::from_str::<T>(s)
38        .map(Some)
39        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
40}
41
42pub async fn read_json_line<R, T>(reader: &mut R) -> std::io::Result<Option<T>>
43where
44    R: AsyncBufRead + Unpin,
45    T: DeserializeOwned,
46{
47    read_json_line_with_limit(reader, DEFAULT_MAX_LINE_BYTES).await
48}
49
50pub async fn write_json_line<W, T>(writer: &mut W, value: &T) -> std::io::Result<()>
51where
52    W: AsyncWrite + Unpin,
53    T: Serialize,
54{
55    let json = serde_json::to_string(value)
56        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
57    writer.write_all(json.as_bytes()).await?;
58    writer.write_all(b"\n").await?;
59    Ok(())
60}
61
62pub async fn serve_jsonl_connection<Req, Resp, F, Fut, InvalidResp>(
63    stream: UnixStream,
64    handler: F,
65    invalid_response: InvalidResp,
66) -> std::io::Result<()>
67where
68    Req: DeserializeOwned,
69    Resp: Serialize,
70    F: Fn(Req) -> Fut,
71    Fut: Future<Output = Resp>,
72    InvalidResp: Fn(std::io::Error) -> Resp,
73{
74    let (reader, mut writer) = stream.into_split();
75    let mut reader = BufReader::new(reader);
76
77    loop {
78        let Some(req) = (match read_json_line::<_, Req>(&mut reader).await {
79            Ok(v) => v,
80            Err(e) if e.kind() == std::io::ErrorKind::InvalidData => {
81                let resp = invalid_response(e);
82                let _ = write_json_line(&mut writer, &resp).await;
83                continue;
84            }
85            Err(e) => return Err(e),
86        }) else {
87            break;
88        };
89
90        let resp = handler(req).await;
91        write_json_line(&mut writer, &resp).await?;
92    }
93
94    Ok(())
95}
96
#[cfg(test)]
mod tests {
    // `BufReader`, `UnixStream`, and the tokio I/O extension traits all come from the parent module.
    use super::*;

    #[tokio::test]
    async fn roundtrips_struct_over_jsonl() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Msg {
            kind: String,
            n: u64,
        }

        let (a, b) = tokio::io::duplex(1024);
        let (mut ar, mut aw) = tokio::io::split(a);
        let (mut br, mut bw) = tokio::io::split(b);
        let mut ar = BufReader::new(&mut ar);
        let mut br = BufReader::new(&mut br);

        let a_send = Msg {
            kind: "hello".to_string(),
            n: 42,
        };
        write_json_line(&mut aw, &a_send).await.unwrap();
        let b_recv: Msg = read_json_line(&mut br).await.unwrap().unwrap();
        assert_eq!(b_recv, a_send);

        let b_send = Msg {
            kind: "world".to_string(),
            n: 7,
        };
        write_json_line(&mut bw, &b_send).await.unwrap();
        let a_recv: Msg = read_json_line(&mut ar).await.unwrap().unwrap();
        assert_eq!(a_recv, b_send);
    }

    #[tokio::test]
    async fn returns_invalid_data_on_bad_json() {
        let (a, b) = tokio::io::duplex(1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        aw.write_all(b"{not json}\n").await.unwrap();

        let err = read_json_line::<_, serde_json::Value>(&mut br)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn errors_when_line_exceeds_limit() {
        let (a, b) = tokio::io::duplex(1024);
        let (_ar, mut aw) = tokio::io::split(a);
        let (mut br, _bw) = tokio::io::split(b);
        let mut br = BufReader::new(&mut br);

        // Write a line bigger than our limit.
        let big = "a".repeat(33);
        aw.write_all(big.as_bytes()).await.unwrap();
        aw.write_all(b"\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 32)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[tokio::test]
    async fn recovers_after_oversized_line() {
        let (mut a, b) = tokio::io::duplex(1024);
        let mut br = BufReader::new(b);

        // An over-long (but otherwise valid) JSON string line, then a normal line.
        let big = format!("\"{}\"\n", "a".repeat(40));
        a.write_all(big.as_bytes()).await.unwrap();
        a.write_all(b"{\"n\":1}\n").await.unwrap();

        let err = read_json_line_with_limit::<_, serde_json::Value>(&mut br, 32)
            .await
            .unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);

        // The oversized line must have been fully skipped.
        let next: serde_json::Value = read_json_line_with_limit(&mut br, 32)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(next, serde_json::json!({ "n": 1 }));
    }

    #[tokio::test]
    async fn parses_unterminated_final_line_then_reports_eof() {
        let (mut a, b) = tokio::io::duplex(1024);
        let mut br = BufReader::new(b);

        a.write_all(b"{\"n\":2}").await.unwrap();
        // Dropping the other end closes the duplex, so the reader sees EOF.
        drop(a);

        let v: serde_json::Value = read_json_line(&mut br).await.unwrap().unwrap();
        assert_eq!(v, serde_json::json!({ "n": 2 }));
        assert!(read_json_line::<_, serde_json::Value>(&mut br)
            .await
            .unwrap()
            .is_none());
    }

    #[tokio::test]
    async fn serve_jsonl_connection_handles_invalid_and_valid_requests() {
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Req {
            n: u64,
        }
        #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
        struct Resp {
            ok: bool,
            n: u64,
        }

        let (a, b) = UnixStream::pair().unwrap();
        let h = tokio::spawn(async move {
            serve_jsonl_connection(
                a,
                |req: Req| async move { Resp { ok: true, n: req.n } },
                |_e| Resp { ok: false, n: 0 },
            )
            .await
            .unwrap();
        });

        let (r, mut w) = b.into_split();
        let mut r = BufReader::new(r);

        // Invalid JSON should yield an error response.
        w.write_all(b"{not json}\n").await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: false, n: 0 });

        // Valid JSON should roundtrip.
        write_json_line(&mut w, &Req { n: 7 }).await.unwrap();
        let resp: Resp = read_json_line(&mut r).await.unwrap().unwrap();
        assert_eq!(resp, Resp { ok: true, n: 7 });

        drop(w);
        h.await.unwrap();
    }
}