Skip to main content

cordyceps_api/
client.rs

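//! # Client
//!
//! Clients that POST JSON payloads to OpenAI's API and hand back the response body as a
//! stream of byte chunks.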
3
4use bytes::Bytes;
5pub use futures_util::stream::{Stream, StreamExt};
6use reqwest::Client as ReqwestClient;
7pub use reqwest::Result as ReqwestResult;
8use serde::Serialize;
9
/// Catch-all error type returned by the clients in this module.
///
/// Any thread-safe (`Send + Sync`) error can be boxed into it, including plain strings.
/// It is expected to be replaced by a more specific error type in the future.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
12
/// A [`Client`] specialised for chat [`Payload`](crate::chat::Payload)s.
///
/// Every request is sent to [`crate::chat::API_URL`].
#[cfg(feature = "chat")]
pub struct ChatClient(Client<crate::chat::Payload>);

#[cfg(feature = "chat")]
impl ChatClient {
    /// Builds a chat client that authenticates with `api_key` and targets
    /// [`crate::chat::API_URL`].
    pub fn new(api_key: impl Into<String>) -> Self {
        let inner = Client::new(api_key, crate::chat::API_URL);
        Self(inner)
    }

    /// Sends a chat `payload` through the wrapped [`Client`].
    ///
    /// See [`Client::send`] for the shape of the returned stream and the errors it can produce.
    pub async fn send(
        &self,
        payload: &crate::chat::Payload,
    ) -> Result<impl Stream<Item = ReqwestResult<Bytes>>, Error> {
        let Self(inner) = self;
        inner.send(payload).await
    }
}
30
31/// A generic client for sending json payloads to OpenAi's API.
32pub struct Client<P: Serialize + ?Sized> {
33    api_key: String,
34    api_url: String,
35
36    marker: std::marker::PhantomData<P>,
37}
38
39impl<P: Serialize + ?Sized> Client<P> {
40    pub fn new(api_key: impl Into<String>, api_url: impl Into<String>) -> Self {
41        Self {
42            api_key: api_key.into(),
43            api_url: api_url.into(),
44            marker: std::marker::PhantomData,
45        }
46    }
47
48    /// Sends a payload to the API. Returns a stream of bytes that can be asynchronously awaited.
49    pub async fn send(
50        &self,
51        payload: &P,
52    ) -> Result<impl Stream<Item = ReqwestResult<Bytes>>, Error> {
53        let req = ReqwestClient::new()
54            .post(&self.api_url)
55            .bearer_auth(&self.api_key)
56            .json(&payload)
57            .send()
58            .await?;
59
60        if !req.status().is_success() {
61            return Err(format!(
62                "Could not request openai with status code: {}",
63                req.status()
64            )
65            .into());
66        }
67
68        let resp = req.bytes_stream().filter_map(|result| async move {
69            match result {
70                Ok(bytes) => Some(Ok(bytes.slice(6..))), // Removes the b"data: " prefix. Thank you
71                // openai!
72                Err(_) => Some(result),
73            }
74        });
75
76        Ok(Box::pin(resp))
77    }
78}