1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
use std::time::Duration;

use api::{
    chat::{ChatRequest, ChatStreamRequest, ChatStreamResponse},
    classify::{Classification, ClassifyRequest, ClassifyResponse},
    detect_language::{DetectLanguageRequest, DetectLanguageResponse, DetectLanguageResult},
    detokenize::{DetokenizeRequest, DetokenizeResponse},
    embed::{EmbedRequest, EmbedResponse},
    generate::{GenerateRequest, GenerateResponse, Generation},
    rerank::{ReRankRequest, ReRankResponse, ReRankResult},
    summarize::{SummarizeRequest, SummarizeResponse},
    tokenize::{TokenizeRequest, TokenizeResponse},
};
use reqwest::{header, ClientBuilder, StatusCode, Url};
use tokio::sync::mpsc::{channel, Receiver};

/// Base URL of the hosted Cohere API; `Cohere::default` appends the version path segment to it.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
/// API version string that `Cohere::default` passes on to be sent in the `Cohere-Version` header.
const COHERE_API_LATEST_VERSION: &str = "2022-12-06";
/// Version path segment that `Cohere::default` joins onto the base URL.
const COHERE_API_V1: &str = "v1";

use serde::{de::DeserializeOwned, Deserialize, Serialize};
use thiserror::Error;

pub mod api;

/// Errors returned by the request methods of the `Cohere` client.
#[derive(Error, Debug)]
pub enum CohereApiError {
    /// An error reported by `reqwest`; converted automatically through `From` (so `?` works).
    #[error("Unexpected request error")]
    RequestError(#[from] reqwest::Error),
    /// The API answered with a client- or server-error status. Carries the HTTP status code
    /// and the error message taken from the response body.
    #[error("API request failed with status code `{0}` and error message `{1}`")]
    ApiError(StatusCode, String),
    /// The API key check reported the key as not valid.
    #[error("API key is invalid")]
    InvalidApiKey,
    /// Catch-all for failures not covered by the other variants.
    #[error("Unknown error")]
    Unknown,
}

/// Errors delivered as items of a streamed response.
#[derive(Error, Debug)]
pub enum CohereStreamError {
    /// A piece of the stream could not be deserialized from JSON; wraps the `serde_json`
    /// error (converted automatically through `From`).
    #[error("Unexpected deserialization error")]
    RequestError(#[from] serde_json::error::Error),
    /// Catch-all for failures that carry only a textual description.
    #[error("Unknown error `{0}`")]
    Unknown(String),
}

/// Cohere Rust SDK to build natural language understanding and generation into your product with a few lines of code.
pub struct Cohere {
    /// Base URL that route names are appended to (as `{api_url}/{route}`) when building request URLs.
    api_url: String,
    /// HTTP client carrying the default headers and timeout configured in `Cohere::new`.
    client: reqwest::Client,
}

/// Response body of the `check-api-key` route.
#[derive(Deserialize, Debug)]
struct CohereCheckApiKeyResponse {
    /// `true` when the API accepted the key; `check_api_key` turns `false` into
    /// `CohereApiError::InvalidApiKey`.
    valid: bool,
}

/// JSON shape that error response bodies are parsed into when a request fails.
#[derive(Deserialize, Debug)]
struct CohereApiErrorResponse {
    /// Error message, forwarded as the text of `CohereApiError::ApiError`.
    message: String,
}

impl Default for Cohere {
    /// Builds a client for `{COHERE_API_BASE_URL}/{COHERE_API_V1}` pinned to
    /// `COHERE_API_LATEST_VERSION`, reading the API key from the `COHERE_API_KEY`
    /// environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `COHERE_API_KEY` cannot be read from the environment, or if
    /// [`Cohere::new`] panics.
    fn default() -> Self {
        let key = std::env::var("COHERE_API_KEY")
            .expect("please provide a Cohere API key with the 'COHERE_API_KEY' env variable");
        let endpoint = format!("{COHERE_API_BASE_URL}/{COHERE_API_V1}");

        Self::new(endpoint, key, COHERE_API_LATEST_VERSION)
    }
}

impl Cohere {
    pub fn new<U: Into<String>, K: Into<String>, V: Into<String>>(
        api_url: U,
        api_key: K,
        version: V,
    ) -> Self {
        let api_url: String = api_url.into();
        let api_key: String = api_key.into();
        let version: String = version.into();

        let mut headers = header::HeaderMap::new();

        let mut authorization = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
            .expect("failed to construct authorization header!");
        authorization.set_sensitive(true);
        headers.insert(header::AUTHORIZATION, authorization);

        headers.insert(
            "Cohere-Version",
            header::HeaderValue::from_str(&version)
                .expect("failed to construct cohere version header!"),
        );

        headers.insert(
            "Request-Source",
            header::HeaderValue::from_static("rust-sdk"),
        );

        headers.insert(
            header::ACCEPT,
            header::HeaderValue::from_static("application/json"),
        );
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );

        let client = ClientBuilder::new()
            .default_headers(headers)
            .use_rustls_tls()
            .timeout(Duration::from_secs(90))
            .build()
            .expect("failed to initialize HTTP client!");

        Cohere { api_url, client }
    }

    /// Sends `payload` as a JSON `POST` to `{api_url}/{route}` and decodes the JSON reply
    /// into `Response`.
    ///
    /// If the response carries an `X-API-Warning` header, its text is written to stderr.
    ///
    /// # Errors
    ///
    /// * [`CohereApiError::RequestError`] if sending the request, reading the body, or
    ///   decoding a successful body fails.
    /// * [`CohereApiError::ApiError`] if the API answers with a 4xx or 5xx status. The
    ///   message is the `message` field of the JSON error body, or
    ///   `Unknown API Error: <raw body>` when the body does not have that shape.
    ///
    /// # Panics
    ///
    /// Panics if `{api_url}/{route}` is not a valid URL.
    async fn request<Request: Serialize, Response: DeserializeOwned>(
        &self,
        route: &'static str,
        payload: Request,
    ) -> Result<Response, CohereApiError> {
        let url =
            Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");

        let response = self.client.post(url).json(&payload).send().await?;

        // Surface any API warning. Header values are raw bytes, so decode them lossily to
        // print readable text instead of a list of byte values.
        if let Some(warning) = response.headers().get("X-API-Warning") {
            eprintln!("Warning: {}", String::from_utf8_lossy(warning.as_bytes()));
        }

        let status = response.status();
        if status.is_client_error() || status.is_server_error() {
            let text = response.text().await?;
            // Only build the fallback message when the body is not the expected error JSON.
            let message = serde_json::from_str::<CohereApiErrorResponse>(&text)
                .map(|error| error.message)
                .unwrap_or_else(|_| format!("Unknown API Error: {}", text));
            return Err(CohereApiError::ApiError(status, message));
        }

        Ok(response.json::<Response>().await?)
    }

    async fn request_stream<Request: Serialize, Response: DeserializeOwned + Send + 'static>(
        &self,
        route: &'static str,
        payload: Request,
    ) -> Result<Receiver<Result<Response, CohereStreamError>>, CohereApiError> {
        let url =
            Url::parse(&format!("{}/{route}", self.api_url)).expect("api url should be valid");

        let mut response = self.client.post(url).json(&payload).send().await?;

        let (tx, rx) = channel::<Result<Response, CohereStreamError>>(32);
        tokio::spawn(async move {
            while let Ok(Some(chunk)) = response.chunk().await {
                if chunk.is_empty() {
                    break;
                }
                match serde_json::from_slice::<Response>(&chunk) {
                    Ok(v) => tx
                        .send(Ok(v))
                        .await
                        .expect("Failed to send message to channel"),
                    Err(e) => tx
                        .send(Err(CohereStreamError::from(e)))
                        .await
                        .expect("Failed to send error to channel"),
                }
            }
        });

        Ok(rx)
    }

    /// Verify that the Cohere API key being used is valid
    pub async fn check_api_key(&self) -> Result<(), CohereApiError> {
        let response = self
            .request::<(), CohereCheckApiKeyResponse>("check-api-key", ())
            .await?;

        match response.valid {
            true => Ok(()),
            false => Err(CohereApiError::InvalidApiKey),
        }
    }

    /// Generates realistic text conditioned on the input described by `request`.
    ///
    /// Posts to the `generate` route and returns the `generations` of the response.
    pub async fn generate<'input>(
        &self,
        request: &GenerateRequest<'input>,
    ) -> Result<Vec<Generation>, CohereApiError> {
        self.request::<_, GenerateResponse>("generate", request)
            .await
            .map(|reply| reply.generations)
    }

    /// Chats with Cohere's LLM, streaming the reply.
    ///
    /// Wraps `request` in a [`ChatStreamRequest`] with `stream` set to `true`, posts it to
    /// the `chat` route and returns the receiving end of the channel on which the streamed
    /// [`ChatStreamResponse`] items (or stream errors) arrive.
    pub async fn chat<'input>(
        &self,
        request: &ChatRequest<'input>,
    ) -> Result<Receiver<Result<ChatStreamResponse, CohereStreamError>>, CohereApiError> {
        let streaming = ChatStreamRequest {
            request,
            stream: true,
        };

        self.request_stream::<_, ChatStreamResponse>("chat", streaming)
            .await
    }

    /// Returns text embeddings.
    ///
    /// An embedding is a list of floating point numbers that captures semantic information
    /// about the text it represents. Embeddings can be used to build text classifiers and
    /// to power semantic search.
    ///
    /// Posts to the `embed` route and returns the `embeddings` of the response.
    pub async fn embed<'input>(
        &self,
        request: &EmbedRequest<'input>,
    ) -> Result<Vec<Vec<f64>>, CohereApiError> {
        self.request::<_, EmbedResponse>("embed", request)
            .await
            .map(|reply| reply.embeddings)
    }

    /// Predicts which label best fits each of the given text inputs, using the provided
    /// text + label example pairs as a reference.
    ///
    /// Posts to the `classify` route and returns the `classifications` of the response.
    pub async fn classify<'input>(
        &self,
        request: &ClassifyRequest<'input>,
    ) -> Result<Vec<Classification>, CohereApiError> {
        self.request::<_, ClassifyResponse>("classify", request)
            .await
            .map(|reply| reply.classifications)
    }

    /// Generates a summary in English for a given text.
    ///
    /// Posts to the `summarize` route and returns the `summary` of the response.
    pub async fn summarize<'input>(
        &self,
        request: &SummarizeRequest<'input>,
    ) -> Result<String, CohereApiError> {
        self.request::<_, SummarizeResponse>("summarize", request)
            .await
            .map(|reply| reply.summary)
    }

    /// Splits input text into smaller units called tokens using byte-pair encoding (BPE).
    ///
    /// Posts to the `tokenize` route and returns the whole decoded response.
    pub async fn tokenize<'input>(
        &self,
        request: &TokenizeRequest<'input>,
    ) -> Result<TokenizeResponse, CohereApiError> {
        self.request::<_, TokenizeResponse>("tokenize", request)
            .await
    }

    /// Takes byte-pair-encoded tokens and returns their text representation.
    ///
    /// Posts to the `detokenize` route and returns the `text` of the response.
    pub async fn detokenize<'input>(
        &self,
        request: &DetokenizeRequest<'input>,
    ) -> Result<String, CohereApiError> {
        self.request::<_, DetokenizeResponse>("detokenize", request)
            .await
            .map(|reply| reply.text)
    }

    /// Identifies which language each of the provided texts is written in.
    ///
    /// Posts to the `detect-language` route and returns the `results` of the response.
    pub async fn detect_language<'input>(
        &self,
        request: &DetectLanguageRequest<'input>,
    ) -> Result<Vec<DetectLanguageResult>, CohereApiError> {
        self.request::<_, DetectLanguageResponse>("detect-language", request)
            .await
            .map(|reply| reply.results)
    }

    /// Takes a query plus a list of texts and returns an ordered array with each text
    /// assigned a relevance score.
    ///
    /// Posts to the `rerank` route and returns the `results` of the response.
    pub async fn rerank<'input>(
        &self,
        request: &ReRankRequest<'input>,
    ) -> Result<Vec<ReRankResult>, CohereApiError> {
        self.request::<_, ReRankResponse>("rerank", request)
            .await
            .map(|reply| reply.results)
    }
}