Skip to main content

xs/client/
commands.rs

1use futures::StreamExt;
2
3use base64::Engine;
4use ssri::Integrity;
5
6use http_body_util::{combinators::BoxBody, BodyExt, Empty, StreamBody};
7use hyper::body::Bytes;
8use hyper::Method;
9use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
10use tokio::sync::mpsc::Receiver;
11use tokio_util::io::ReaderStream;
12
13use super::request;
14use crate::store::{ReadOptions, TTL};
15
16pub async fn cat(
17    addr: &str,
18    options: ReadOptions,
19    sse: bool,
20) -> Result<Receiver<Bytes>, Box<dyn std::error::Error + Send + Sync>> {
21    // Convert any usize limit to u64
22    let query = if options == ReadOptions::default() {
23        None
24    } else {
25        Some(options.to_query_string())
26    };
27
28    let headers = if sse {
29        Some(vec![(
30            "Accept".to_string(),
31            "text/event-stream".to_string(),
32        )])
33    } else {
34        None
35    };
36
37    let res = request::request(addr, Method::GET, "", query.as_deref(), empty(), headers).await?;
38
39    let (_parts, mut body) = res.into_parts();
40    let (tx, rx) = tokio::sync::mpsc::channel(100);
41
42    tokio::spawn(async move {
43        while let Some(frame_result) = body.frame().await {
44            match frame_result {
45                Ok(frame) => {
46                    if let Ok(bytes) = frame.into_data() {
47                        if tx.send(bytes).await.is_err() {
48                            break;
49                        }
50                    }
51                }
52                Err(e) => {
53                    eprintln!("Error reading body: {e}");
54                    break;
55                }
56            }
57        }
58    });
59
60    Ok(rx)
61}
62
63pub async fn append<R>(
64    addr: &str,
65    topic: &str,
66    data: R,
67    meta: Option<&serde_json::Value>,
68    ttl: Option<TTL>,
69) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
70where
71    R: AsyncRead + Unpin + Send + 'static,
72{
73    let query = ttl.map(|t| t.to_query());
74
75    let reader_stream = ReaderStream::new(data);
76    let mapped_stream = reader_stream.map(|result| {
77        result
78            .map(hyper::body::Frame::data)
79            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
80    });
81    let body = StreamBody::new(mapped_stream);
82
83    let headers = meta.map(|meta_value| {
84        let json_string = serde_json::to_string(meta_value).unwrap();
85        let encoded = base64::prelude::BASE64_STANDARD.encode(json_string);
86        vec![("xs-meta".to_string(), encoded)]
87    });
88
89    let res = request::request(
90        addr,
91        Method::POST,
92        &format!("append/{topic}"),
93        query.as_deref(),
94        body,
95        headers,
96    )
97    .await?;
98    let body = res.collect().await?.to_bytes();
99    Ok(body)
100}
101
102pub async fn cas_get<W>(
103    addr: &str,
104    integrity: Integrity,
105    writer: &mut W,
106) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
107where
108    W: AsyncWrite + Unpin,
109{
110    let parts = super::types::RequestParts::parse(addr, &format!("cas/{integrity}"), None)?;
111
112    match parts.connection {
113        super::types::ConnectionKind::Unix(path) => {
114            // Direct CAS access for local path
115            let store_path = path.parent().unwrap_or(&path).to_path_buf();
116            let cas_path = store_path.join("cacache");
117            match cacache::Reader::open_hash(&cas_path, integrity).await {
118                Ok(mut reader) => {
119                    tokio::io::copy(&mut reader, writer).await?;
120                    writer.flush().await?;
121                    Ok(())
122                }
123                Err(e) => {
124                    // Check if this is an entry not found error
125                    if matches!(e, cacache::Error::EntryNotFound(_, _)) {
126                        return Err(Box::new(crate::error::NotFound));
127                    }
128                    // Also check for IO not found errors in the chain
129                    let boxed_err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
130                    if crate::error::has_not_found_io_error(&boxed_err) {
131                        return Err(Box::new(crate::error::NotFound));
132                    }
133                    Err(boxed_err)
134                }
135            }
136        }
137        _ => {
138            // Remote HTTP access
139            let res = request::request(
140                addr,
141                Method::GET,
142                &format!("cas/{integrity}"),
143                None,
144                empty(),
145                None,
146            )
147            .await?;
148            let mut body = res.into_body();
149
150            while let Some(frame) = body.frame().await {
151                let frame = frame?;
152                if let Ok(chunk) = frame.into_data() {
153                    writer.write_all(&chunk).await?;
154                }
155            }
156
157            writer.flush().await?;
158            Ok(())
159        }
160    }
161}
162
163pub async fn cas_post<R>(
164    addr: &str,
165    data: R,
166) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
167where
168    R: AsyncRead + Unpin + Send + 'static,
169{
170    let reader_stream = ReaderStream::new(data);
171    let mapped_stream = reader_stream.map(|result| {
172        result
173            .map(hyper::body::Frame::data)
174            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
175    });
176    let body = StreamBody::new(mapped_stream);
177
178    let res = request::request(addr, Method::POST, "cas", None, body, None).await?;
179    let body = res.collect().await?.to_bytes();
180    Ok(body)
181}
182
183pub async fn get(addr: &str, id: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
184    let res = request::request(addr, Method::GET, id, None, empty(), None).await?;
185    let body = res.collect().await?.to_bytes();
186    Ok(body)
187}
188
189pub async fn remove(addr: &str, id: &str) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
190    let _ = request::request(addr, Method::DELETE, id, None, empty(), None).await?;
191    Ok(())
192}
193
194pub async fn last(
195    addr: &str,
196    topic: Option<&str>,
197    last: usize,
198    follow: bool,
199) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
200    let mut query_parts = Vec::new();
201    if last != 1 {
202        query_parts.push(format!("last={last}"));
203    }
204    if follow {
205        query_parts.push("follow=true".to_string());
206    }
207    let query = if query_parts.is_empty() {
208        None
209    } else {
210        Some(query_parts.join("&"))
211    };
212
213    let path = match topic {
214        Some(t) => format!("last/{t}"),
215        None => "last".to_string(),
216    };
217
218    let res = request::request(addr, Method::GET, &path, query.as_deref(), empty(), None).await?;
219
220    let mut body = res.into_body();
221    let mut stdout = tokio::io::stdout();
222
223    while let Some(frame) = body.frame().await {
224        let frame = frame?;
225        if let Ok(chunk) = frame.into_data() {
226            stdout.write_all(&chunk).await?;
227        }
228    }
229    stdout.flush().await?;
230    Ok(())
231}
232
233pub async fn import<R>(
234    addr: &str,
235    data: R,
236) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>>
237where
238    R: AsyncRead + Unpin + Send + 'static,
239{
240    let reader_stream = ReaderStream::new(data);
241    let mapped_stream = reader_stream.map(|result| {
242        result
243            .map(hyper::body::Frame::data)
244            .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)
245    });
246    let body = StreamBody::new(mapped_stream);
247
248    let res = request::request(addr, Method::POST, "import", None, body, None).await?;
249    let body = res.collect().await?.to_bytes();
250    Ok(body)
251}
252
253pub async fn version(addr: &str) -> Result<Bytes, Box<dyn std::error::Error + Send + Sync>> {
254    match request::request(addr, Method::GET, "version", None, empty(), None).await {
255        Ok(res) => {
256            let body = res.collect().await?.to_bytes();
257            Ok(body)
258        }
259        Err(e) => {
260            // this was the version before the /version endpoint was added
261            if crate::error::NotFound::is_not_found(&e) {
262                Ok(Bytes::from(r#"{"version":"0.0.9"}"#))
263            } else {
264                Err(e)
265            }
266        }
267    }
268}
269
270fn empty() -> BoxBody<Bytes, Box<dyn std::error::Error + Send + Sync>> {
271    Empty::<Bytes>::new()
272        .map_err(|never| match never {})
273        .boxed()
274}
275
276pub async fn eval(
277    addr: &str,
278    script: String,
279) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
280    let res = request::request(addr, Method::POST, "eval", None, script, None).await?;
281
282    let mut body = res.into_body();
283    let mut stdout = tokio::io::stdout();
284
285    while let Some(frame) = body.frame().await {
286        let frame = frame?;
287        if let Ok(chunk) = frame.into_data() {
288            stdout.write_all(&chunk).await?;
289        }
290    }
291    stdout.flush().await?;
292    Ok(())
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;
    use tempfile::TempDir;

    /// `cas_get` against a local store must report a hash that was never
    /// stored as `NotFound`.
    #[tokio::test]
    async fn test_cas_get_not_found_local() {
        // Keep `tmp` alive for the whole test so the directory is not removed.
        let tmp = TempDir::new().unwrap();
        let addr = tmp.path().to_str().unwrap();

        // A syntactically valid integrity string for content that does not exist.
        let missing =
            Integrity::from_str("sha256-fakehashnotfound0000000000000000000000000000000=")
                .unwrap();

        let mut sink = Vec::new();
        let outcome = cas_get(addr, missing, &mut sink).await;

        assert!(outcome.is_err());
        assert!(crate::error::NotFound::is_not_found(&outcome.unwrap_err()));
    }
}
318}