xs/client/
commands.rs

1use futures::StreamExt;
2
3use base64::Engine;
4use ssri::Integrity;
5use url::form_urlencoded;
6
7use http_body_util::{combinators::BoxBody, BodyExt, Empty, StreamBody};
8use hyper::body::Bytes;
9use hyper::Method;
10use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
11use tokio::sync::mpsc::Receiver;
12use tokio_util::io::ReaderStream;
13
14use super::request;
15use crate::store::{ReadOptions, TTL};
16
17pub async fn cat(
18    addr: &str,
19    options: ReadOptions,
20    sse: bool,
21) -> Result<Receiver<Bytes>, Box<dyn std::error::Error + Send + Sync>> {
22    // Convert any usize limit to u64
23    let query = if options == ReadOptions::default() {
24        None
25    } else {
26        Some(options.to_query_string())
27    };
28
29    let headers = if sse {
30        Some(vec![(
31            "Accept".to_string(),
32            "text/event-stream".to_string(),
33        )])
34    } else {
35        None
36    };
37
38    let res = request::request(addr, Method::GET, "", query.as_deref(), empty(), headers).await?;
39
40    let (_parts, mut body) = res.into_parts();
41    let (tx, rx) = tokio::sync::mpsc::channel(100);
42
43    tokio::spawn(async move {
44        while let Some(frame_result) = body.frame().await {
45            match frame_result {
46                Ok(frame) => {
47                    if let Ok(bytes) = frame.into_data() {
48                        if tx.send(bytes).await.is_err() {
49                            break;
50                        }
51                    }
52                }
53                Err(e) => {
54                    eprintln!("Error reading body: {e}");
55                    break;
56                }
57            }
58        }
59    });
60
61    Ok(rx)
62}
63
64pub async fn append<R>(
65    addr: &str,
66    topic: &str,
67    data: R,
68    meta: Option<&serde_json::Value>,
69    ttl: Option<TTL>,
70    context: Option<&str>,
71) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
72where
73    R: AsyncRead + Unpin + Send + 'static,
74{
75    let mut params = Vec::new();
76    if let Some(t) = ttl {
77        let ttl_query = t.to_query();
78        if let Some((k, v)) = ttl_query.split_once('=') {
79            params.push((k.to_string(), v.to_string()));
80        }
81    }
82    if let Some(c) = context {
83        params.push(("context".to_string(), c.to_string()));
84    }
85
86    let query = if !params.is_empty() {
87        Some(
88            form_urlencoded::Serializer::new(String::new())
89                .extend_pairs(params)
90                .finish(),
91        )
92    } else {
93        None
94    };
95
96    let reader_stream = ReaderStream::new(data);
97    let mapped_stream = reader_stream.map(|result| {
98        result
99            .map(hyper::body::Frame::data)
100            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
101    });
102    let body = StreamBody::new(mapped_stream);
103
104    let headers = meta.map(|meta_value| {
105        let json_string = serde_json::to_string(meta_value).unwrap();
106        let encoded = base64::prelude::BASE64_STANDARD.encode(json_string);
107        vec![("xs-meta".to_string(), encoded)]
108    });
109
110    let res = request::request(addr, Method::POST, topic, query.as_deref(), body, headers).await?;
111    let body = res.collect().await?.to_bytes();
112    Ok(body)
113}
114
115pub async fn cas_get<W>(
116    addr: &str,
117    integrity: Integrity,
118    writer: &mut W,
119) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
120where
121    W: AsyncWrite + Unpin,
122{
123    let parts = super::types::RequestParts::parse(addr, &format!("cas/{integrity}"), None)?;
124
125    match parts.connection {
126        super::types::ConnectionKind::Unix(path) => {
127            // Direct CAS access for local path
128            let store_path = path.parent().unwrap_or(&path).to_path_buf();
129            let cas_path = store_path.join("cacache");
130            match cacache::Reader::open_hash(&cas_path, integrity).await {
131                Ok(mut reader) => {
132                    tokio::io::copy(&mut reader, writer).await?;
133                    writer.flush().await?;
134                    Ok(())
135                }
136                Err(e) => {
137                    // Check if this is an entry not found error
138                    if matches!(e, cacache::Error::EntryNotFound(_, _)) {
139                        return Err(Box::new(crate::error::NotFound));
140                    }
141                    // Also check for IO not found errors in the chain
142                    let boxed_err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
143                    if crate::error::has_not_found_io_error(&boxed_err) {
144                        return Err(Box::new(crate::error::NotFound));
145                    }
146                    Err(boxed_err)
147                }
148            }
149        }
150        _ => {
151            // Remote HTTP access
152            let res = request::request(
153                addr,
154                Method::GET,
155                &format!("cas/{integrity}"),
156                None,
157                empty(),
158                None,
159            )
160            .await?;
161            let mut body = res.into_body();
162
163            while let Some(frame) = body.frame().await {
164                let frame = frame?;
165                if let Ok(chunk) = frame.into_data() {
166                    writer.write_all(&chunk).await?;
167                }
168            }
169
170            writer.flush().await?;
171            Ok(())
172        }
173    }
174}
175
176pub async fn cas_post<R>(
177    addr: &str,
178    data: R,
179) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
180where
181    R: AsyncRead + Unpin + Send + 'static,
182{
183    let reader_stream = ReaderStream::new(data);
184    let mapped_stream = reader_stream.map(|result| {
185        result
186            .map(hyper::body::Frame::data)
187            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
188    });
189    let body = StreamBody::new(mapped_stream);
190
191    let res = request::request(addr, Method::POST, "cas", None, body, None).await?;
192    let body = res.collect().await?.to_bytes();
193    Ok(body)
194}
195
196pub async fn get(addr: &str, id: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
197    let res = request::request(addr, Method::GET, id, None, empty(), None).await?;
198    let body = res.collect().await?.to_bytes();
199    Ok(body)
200}
201
202pub async fn remove(addr: &str, id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
203    let _ = request::request(addr, Method::DELETE, id, None, empty(), None).await?;
204    Ok(())
205}
206
207pub async fn head(
208    addr: &str,
209    topic: &str,
210    follow: bool,
211    context: Option<&str>,
212) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
213    let mut params = Vec::new();
214    if follow {
215        params.push(("follow", "true".to_string()));
216    }
217    if let Some(c) = context {
218        params.push(("context", c.to_string()));
219    }
220
221    let query = if !params.is_empty() {
222        Some(
223            form_urlencoded::Serializer::new(String::new())
224                .extend_pairs(params)
225                .finish(),
226        )
227    } else {
228        None
229    };
230
231    let res = request::request(
232        addr,
233        Method::GET,
234        &format!("head/{topic}"),
235        query.as_deref(),
236        empty(),
237        None,
238    )
239    .await?;
240
241    let mut body = res.into_body();
242    let mut stdout = tokio::io::stdout();
243
244    while let Some(frame) = body.frame().await {
245        let frame = frame?;
246        if let Ok(chunk) = frame.into_data() {
247            stdout.write_all(&chunk).await?;
248        }
249    }
250    stdout.flush().await?;
251    Ok(())
252}
253
254pub async fn import<R>(
255    addr: &str,
256    data: R,
257) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
258where
259    R: AsyncRead + Unpin + Send + 'static,
260{
261    let reader_stream = ReaderStream::new(data);
262    let mapped_stream = reader_stream.map(|result| {
263        result
264            .map(hyper::body::Frame::data)
265            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
266    });
267    let body = StreamBody::new(mapped_stream);
268
269    let res = request::request(addr, Method::POST, "import", None, body, None).await?;
270    let body = res.collect().await?.to_bytes();
271    Ok(body)
272}
273
274pub async fn version(addr: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
275    match request::request(addr, Method::GET, "version", None, empty(), None).await {
276        Ok(res) => {
277            let body = res.collect().await?.to_bytes();
278            Ok(body)
279        }
280        Err(e) => {
281            // this was the version before the /version endpoint was added
282            if crate::error::NotFound::is_not_found(&e) {
283                Ok(Bytes::from(r#"{"version":"0.0.9"}"#))
284            } else {
285                Err(e)
286            }
287        }
288    }
289}
290
291fn empty() -> BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>> {
292    Empty::<Bytes>::new()
293        .map_err(|never| match never {})
294        .boxed()
295}
296
297pub async fn exec(
298    addr: &str,
299    script: String,
300) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
301    let res = request::request(addr, Method::POST, "exec", None, script, None).await?;
302
303    let mut body = res.into_body();
304    let mut stdout = tokio::io::stdout();
305
306    while let Some(frame) = body.frame().await {
307        let frame = frame?;
308        if let Ok(chunk) = frame.into_data() {
309            stdout.write_all(&chunk).await?;
310        }
311    }
312    stdout.flush().await?;
313    Ok(())
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use tempfile::TempDir;

    /// Looking up a hash that was never stored, using a fresh temporary
    /// directory as the address, must yield an error recognized by
    /// `NotFound::is_not_found`.
    #[tokio::test]
    async fn test_cas_get_not_found_local() {
        let scratch = TempDir::new().unwrap();
        let addr = scratch.path().to_str().unwrap();

        // An integrity string for content that is not in the cache.
        let missing = Integrity::from_str(
            "sha256-fakehashnotfound0000000000000000000000000000000=",
        )
        .unwrap();

        let mut sink = Vec::new();
        let outcome = cas_get(addr, missing, &mut sink).await;

        // The lookup must fail, and specifically with a not-found error.
        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        assert!(crate::error::NotFound::is_not_found(&failure));
    }
}