// xs/client/commands.rs

1use futures::StreamExt;
2
3use base64::Engine;
4use ssri::Integrity;
5use url::form_urlencoded;
6
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, StreamBody};
8use hyper::body::Bytes;
9use hyper::Method;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
11use tokio::sync::mpsc::Receiver;
12use tokio_util::io::ReaderStream;
13
14use super::request;
15use crate::store::{ReadOptions, TTL};
16
17pub async fn cat(
18    addr: &str,
19    options: ReadOptions,
20    sse: bool,
21) -> Result<Receiver<Bytes>, Box<dyn std::error::Error + Send + Sync>> {
22    // Convert any usize limit to u64
23    let query = if options == ReadOptions::default() {
24        None
25    } else {
26        Some(options.to_query_string())
27    };
28
29    let headers = if sse {
30        Some(vec![(
31            "Accept".to_string(),
32            "text/event-stream".to_string(),
33        )])
34    } else {
35        None
36    };
37
38    let res = request::request(addr, Method::GET, "", query.as_deref(), empty(), headers).await?;
39
40    let (_parts, mut body) = res.into_parts();
41    let (tx, rx) = tokio::sync::mpsc::channel(100);
42
43    tokio::spawn(async move {
44        while let Some(frame_result) = body.frame().await {
45            match frame_result {
46                Ok(frame) => {
47                    if let Ok(bytes) = frame.into_data() {
48                        if tx.send(bytes).await.is_err() {
49                            break;
50                        }
51                    }
52                }
53                Err(e) => {
54                    eprintln!("Error reading body: {e}");
55                    break;
56                }
57            }
58        }
59    });
60
61    Ok(rx)
62}
63
64pub async fn append<R>(
65    addr: &str,
66    topic: &str,
67    data: R,
68    meta: Option<&serde_json::Value>,
69    ttl: Option<TTL>,
70    context: Option<&str>,
71) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
72where
73    R: AsyncRead + Unpin + Send + 'static,
74{
75    let mut params = Vec::new();
76    if let Some(t) = ttl {
77        let ttl_query = t.to_query();
78        if let Some((k, v)) = ttl_query.split_once('=') {
79            params.push((k.to_string(), v.to_string()));
80        }
81    }
82    if let Some(c) = context {
83        params.push(("context".to_string(), c.to_string()));
84    }
85
86    let query = if !params.is_empty() {
87        Some(
88            form_urlencoded::Serializer::new(String::new())
89                .extend_pairs(params)
90                .finish(),
91        )
92    } else {
93        None
94    };
95
96    let reader_stream = ReaderStream::new(data);
97    let mapped_stream = reader_stream.map(|result| {
98        result
99            .map(hyper::body::Frame::data)
100            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
101    });
102    let body = StreamBody::new(mapped_stream);
103
104    let headers = meta.map(|meta_value| {
105        let json_string = serde_json::to_string(meta_value).unwrap();
106        let encoded = base64::prelude::BASE64_STANDARD.encode(json_string);
107        vec![("xs-meta".to_string(), encoded)]
108    });
109
110    let res = request::request(addr, Method::POST, topic, query.as_deref(), body, headers).await?;
111    let body = res.collect().await?.to_bytes();
112    Ok(body)
113}
114
115pub async fn cas_get<W>(
116    addr: &str,
117    integrity: Integrity,
118    writer: &mut W,
119) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
120where
121    W: AsyncWrite + Unpin,
122{
123    let parts = super::types::RequestParts::parse(addr, &format!("cas/{integrity}"), None)?;
124
125    match parts.connection {
126        super::types::ConnectionKind::Unix(path) => {
127            // Direct CAS access for local path
128            let store_path = path.parent().unwrap_or(&path).to_path_buf();
129            let cas_path = store_path.join("cacache");
130            match cacache::Reader::open_hash(&cas_path, integrity).await {
131                Ok(mut reader) => {
132                    tokio::io::copy(&mut reader, writer).await?;
133                    writer.flush().await?;
134                    Ok(())
135                }
136                Err(e) => {
137                    // Check if this is a file not found error and convert to NotFound
138                    let boxed_err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
139                    if crate::error::has_not_found_io_error(&boxed_err) {
140                        return Err(Box::new(crate::error::NotFound));
141                    }
142                    Err(boxed_err)
143                }
144            }
145        }
146        _ => {
147            // Remote HTTP access
148            let res = request::request(
149                addr,
150                Method::GET,
151                &format!("cas/{integrity}"),
152                None,
153                empty(),
154                None,
155            )
156            .await?;
157            let mut body = res.into_body();
158
159            while let Some(frame) = body.frame().await {
160                let frame = frame?;
161                if let Ok(chunk) = frame.into_data() {
162                    writer.write_all(&chunk).await?;
163                }
164            }
165
166            writer.flush().await?;
167            Ok(())
168        }
169    }
170}
171
172pub async fn cas_post<R>(
173    addr: &str,
174    data: R,
175) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
176where
177    R: AsyncRead + Unpin + Send + 'static,
178{
179    let reader_stream = ReaderStream::new(data);
180    let mapped_stream = reader_stream.map(|result| {
181        result
182            .map(hyper::body::Frame::data)
183            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
184    });
185    let body = StreamBody::new(mapped_stream);
186
187    let res = request::request(addr, Method::POST, "cas", None, body, None).await?;
188    let body = res.collect().await?.to_bytes();
189    Ok(body)
190}
191
192pub async fn get(addr: &str, id: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
193    let res = request::request(addr, Method::GET, id, None, empty(), None).await?;
194    let body = res.collect().await?.to_bytes();
195    Ok(body)
196}
197
198pub async fn remove(addr: &str, id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
199    let _ = request::request(addr, Method::DELETE, id, None, empty(), None).await?;
200    Ok(())
201}
202
203pub async fn head(
204    addr: &str,
205    topic: &str,
206    follow: bool,
207    context: Option<&str>,
208) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
209    let mut params = Vec::new();
210    if follow {
211        params.push(("follow", "true".to_string()));
212    }
213    if let Some(c) = context {
214        params.push(("context", c.to_string()));
215    }
216
217    let query = if !params.is_empty() {
218        Some(
219            form_urlencoded::Serializer::new(String::new())
220                .extend_pairs(params)
221                .finish(),
222        )
223    } else {
224        None
225    };
226
227    let res = request::request(
228        addr,
229        Method::GET,
230        &format!("head/{topic}"),
231        query.as_deref(),
232        empty(),
233        None,
234    )
235    .await?;
236
237    let mut body = res.into_body();
238    let mut stdout = tokio::io::stdout();
239
240    while let Some(frame) = body.frame().await {
241        let frame = frame?;
242        if let Ok(chunk) = frame.into_data() {
243            stdout.write_all(&chunk).await?;
244        }
245    }
246    stdout.flush().await?;
247    Ok(())
248}
249
250pub async fn import<R>(
251    addr: &str,
252    data: R,
253) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
254where
255    R: AsyncRead + Unpin + Send + 'static,
256{
257    let reader_stream = ReaderStream::new(data);
258    let mapped_stream = reader_stream.map(|result| {
259        result
260            .map(hyper::body::Frame::data)
261            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
262    });
263    let body = StreamBody::new(mapped_stream);
264
265    let res = request::request(addr, Method::POST, "import", None, body, None).await?;
266    let body = res.collect().await?.to_bytes();
267    Ok(body)
268}
269
270pub async fn version(addr: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
271    match request::request(addr, Method::GET, "version", None, empty(), None).await {
272        Ok(res) => {
273            let body = res.collect().await?.to_bytes();
274            Ok(body)
275        }
276        Err(e) => {
277            // this was the version before the /version endpoint was added
278            if crate::error::NotFound::is_not_found(&e) {
279                Ok(Bytes::from(r#"{"version":"0.0.9"}"#))
280            } else {
281                Err(e)
282            }
283        }
284    }
285}
286
287fn empty() -> BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>> {
288    Empty::<Bytes>::new()
289        .map_err(|never| match never {})
290        .boxed()
291}
292
293pub async fn exec(
294    addr: &str,
295    script: String,
296) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
297    let res = request::request(addr, Method::POST, "exec", None, script, None).await?;
298
299    let mut body = res.into_body();
300    let mut stdout = tokio::io::stdout();
301
302    while let Some(frame) = body.frame().await {
303        let frame = frame?;
304        if let Ok(chunk) = frame.into_data() {
305            stdout.write_all(&chunk).await?;
306        }
307    }
308    stdout.flush().await?;
309    Ok(())
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use tempfile::TempDir;

    /// A hash that is absent from the local CAS must surface as `NotFound`
    /// rather than as a generic I/O error.
    #[tokio::test]
    async fn test_cas_get_not_found_local() {
        let temp_dir = TempDir::new().unwrap();
        let addr = temp_dir.path().to_str().unwrap();
        // NOTE(review): assuming `RequestParts::parse` maps this path to
        // `ConnectionKind::Unix`, `cas_get` looks in `<parent of addr>/cacache`.
        // This therefore probes the temp dir's parent, not the temp dir itself.
        let missing =
            Integrity::from_str("sha256-fakehashnotfound0000000000000000000000000000000=")
                .unwrap();

        let mut sink = Vec::new();
        let err = cas_get(addr, missing, &mut sink)
            .await
            .expect_err("lookup of a missing hash should fail");
        assert!(crate::error::NotFound::is_not_found(&err));
    }
}