pharia_skill_test/
lib.rs

1use std::time::Duration;
2
3use pharia_skill::{
4    ChatRequest, ChatResponse, ChunkRequest, Completion, CompletionRequest, Csi, Document,
5    DocumentPath, FinishReason, LanguageCode, Message, SearchRequest, SearchResult,
6    SelectLanguageRequest, TokenUsage,
7};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9use ureq::{json, serde_json::Value, Agent, AgentBuilder};
10
/// A [`Csi`] stub whose responses are empty or echo the input; useful for
/// testing skills that do not depend on the content of inference results.
pub struct StubCsi;
12
13impl Csi for StubCsi {
14    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
15        requests
16            .iter()
17            .map(|_| ChatResponse {
18                message: Message::new("user", ""),
19                finish_reason: FinishReason::Stop,
20                logprobs: vec![],
21                usage: TokenUsage {
22                    prompt: 0,
23                    completion: 0,
24                },
25            })
26            .collect()
27    }
28
29    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
30        requests
31            .into_iter()
32            .map(|request| Completion {
33                text: request.prompt,
34                finish_reason: FinishReason::Stop,
35                logprobs: vec![],
36                usage: TokenUsage {
37                    prompt: 0,
38                    completion: 0,
39                },
40            })
41            .collect()
42    }
43
44    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
45        requests
46            .into_iter()
47            .map(|request| vec![request.text])
48            .collect()
49    }
50
51    fn select_language_concurrently(
52        &self,
53        requests: Vec<SelectLanguageRequest>,
54    ) -> Vec<Option<LanguageCode>> {
55        requests.iter().map(|_| None).collect()
56    }
57
58    fn search_concurrently(&self, _requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
59        vec![]
60    }
61
62    fn documents<Metadata>(
63        &self,
64        _paths: Vec<DocumentPath>,
65    ) -> anyhow::Result<Vec<Document<Metadata>>>
66    where
67        Metadata: for<'a> Deserialize<'a>,
68    {
69        Ok(vec![])
70    }
71
72    fn documents_metadata<Metadata>(
73        &self,
74        _paths: Vec<DocumentPath>,
75    ) -> anyhow::Result<Vec<Option<Metadata>>>
76    where
77        Metadata: for<'a> Deserialize<'a>,
78    {
79        Ok(vec![])
80    }
81}
82
/// A [`Csi`] mock that answers every chat and completion request with one
/// caller-supplied canned response.
pub struct MockCsi {
    // The canned text returned by `chat_concurrently` and `complete_concurrently`.
    response: String,
}
86
87impl MockCsi {
88    #[must_use]
89    pub fn new(response: impl Into<String>) -> Self {
90        Self {
91            response: response.into(),
92        }
93    }
94}
95
96impl Csi for MockCsi {
97    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
98        requests
99            .iter()
100            .map(|_| ChatResponse {
101                message: Message::new("user", self.response.clone()),
102                finish_reason: FinishReason::Stop,
103                logprobs: vec![],
104                usage: TokenUsage {
105                    prompt: 0,
106                    completion: 0,
107                },
108            })
109            .collect()
110    }
111
112    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
113        requests
114            .iter()
115            .map(|_| Completion {
116                text: self.response.clone(),
117                finish_reason: FinishReason::Stop,
118                logprobs: vec![],
119                usage: TokenUsage {
120                    prompt: 0,
121                    completion: 0,
122                },
123            })
124            .collect()
125    }
126
127    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
128        requests
129            .into_iter()
130            .map(|request| vec![request.text])
131            .collect()
132    }
133
134    fn select_language_concurrently(
135        &self,
136        requests: Vec<SelectLanguageRequest>,
137    ) -> Vec<Option<LanguageCode>> {
138        requests.iter().map(|_| None).collect()
139    }
140
141    fn search_concurrently(&self, _requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
142        vec![]
143    }
144
145    fn documents<Metadata>(
146        &self,
147        _paths: Vec<DocumentPath>,
148    ) -> anyhow::Result<Vec<Document<Metadata>>>
149    where
150        Metadata: for<'a> Deserialize<'a>,
151    {
152        Ok(vec![])
153    }
154
155    fn documents_metadata<Metadata>(
156        &self,
157        _paths: Vec<DocumentPath>,
158    ) -> anyhow::Result<Vec<Option<Metadata>>>
159    where
160        Metadata: for<'a> Deserialize<'a>,
161    {
162        Ok(vec![])
163    }
164}
165
/// Identifies which CSI function a [`CsiRequest`] dispatches to on the Kernel.
///
/// Variants serialize in `snake_case`, e.g. `SelectLanguage` becomes
/// `"select_language"` on the wire.
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
enum Function {
    Complete,
    Chunk,
    SelectLanguage,
    Search,
    Chat,
    Documents,
    // Singular on the wire ("document_metadata"), even though the trait
    // method is named `documents_metadata`.
    DocumentMetadata,
}
177
/// Wire format of a request body sent to the Kernel's `/csi` endpoint.
#[derive(Serialize)]
struct CsiRequest<'a, P: Serialize> {
    /// API version the client speaks (see `DevCsi::VERSION`).
    version: &'a str,
    /// The CSI function to invoke.
    function: Function,
    /// Function-specific arguments, flattened into the top-level JSON object.
    #[serde(flatten)]
    payload: P,
}
185
/// A Csi implementation that can be used for testing within normal Rust targets.
///
/// Forwards every CSI call as a blocking HTTP request to a running Pharia
/// Kernel, authenticated with a bearer token.
pub struct DevCsi {
    // Base URL of the Kernel instance, without trailing slash.
    address: String,
    // Blocking HTTP client, configured with a timeout in `DevCsi::new`.
    agent: Agent,
    // Bearer token sent in the `Authorization` header of every request.
    token: String,
}
192
193impl DevCsi {
194    /// The version of the API we are calling against
195    const VERSION: &str = "0.3";
196
197    #[must_use]
198    pub fn new(address: impl Into<String>, token: impl Into<String>) -> Self {
199        let agent = AgentBuilder::new()
200            .timeout(Duration::from_secs(60 * 5))
201            .build();
202        Self {
203            address: address.into(),
204            agent,
205            token: token.into(),
206        }
207    }
208
209    /// Construct a new [`DevCsi`] that points to the Aleph Alpha hosted Kernel
210    pub fn aleph_alpha(token: impl Into<String>) -> Self {
211        Self::new("https://pharia-kernel.product.pharia.com", token)
212    }
213
214    fn csi_request<R: DeserializeOwned>(
215        &self,
216        function: Function,
217        payload: impl Serialize,
218    ) -> anyhow::Result<R> {
219        let json = CsiRequest {
220            version: Self::VERSION,
221            function,
222            payload,
223        };
224        let response = self
225            .agent
226            .post(&format!("{}/csi", &self.address))
227            .set("Authorization", &format!("Bearer {}", self.token))
228            .send_json(json);
229
230        match response {
231            Ok(response) => Ok(response.into_json::<R>()?),
232            Err(ureq::Error::Status(status, response)) => {
233                panic!(
234                    "Failed Request: Status {status} {}",
235                    response.into_json::<Value>().unwrap_or_default()
236                );
237            }
238            Err(e) => {
239                panic!("{e}")
240            }
241        }
242    }
243}
244
245impl Csi for DevCsi {
246    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
247        self.csi_request(Function::Chat, json!({"requests": requests}))
248            .unwrap()
249    }
250
251    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
252        self.csi_request(Function::Complete, json!({"requests": requests}))
253            .unwrap()
254    }
255
256    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
257        self.csi_request(Function::Chunk, json!({"requests": requests}))
258            .unwrap()
259    }
260
261    fn select_language_concurrently(
262        &self,
263        requests: Vec<SelectLanguageRequest>,
264    ) -> Vec<Option<LanguageCode>> {
265        self.csi_request(Function::SelectLanguage, json!({"requests": requests}))
266            .unwrap()
267    }
268
269    fn search_concurrently(&self, requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
270        self.csi_request(Function::Search, json!({"requests": requests}))
271            .unwrap()
272    }
273
274    fn documents<Metadata>(
275        &self,
276        paths: Vec<DocumentPath>,
277    ) -> anyhow::Result<Vec<Document<Metadata>>>
278    where
279        Metadata: for<'a> Deserialize<'a> + Serialize,
280    {
281        Ok(self
282            .csi_request::<Vec<Document<Metadata>>>(
283                Function::Documents,
284                json!({"requests": paths}),
285            )?
286            .into_iter()
287            .collect())
288    }
289
290    fn documents_metadata<Metadata>(
291        &self,
292        paths: Vec<DocumentPath>,
293    ) -> anyhow::Result<Vec<Option<Metadata>>>
294    where
295        Metadata: for<'a> Deserialize<'a> + Serialize,
296    {
297        Ok(self
298            .csi_request::<Vec<Option<Metadata>>>(
299                Function::DocumentMetadata,
300                json!({"requests": paths}),
301            )?
302            .into_iter()
303            .collect())
304    }
305}
306
// Integration tests against a live Pharia Kernel. They require network access
// and a `PHARIA_AI_TOKEN` environment variable (optionally via a `.env` file
// loaded with dotenvy) and therefore will not pass in an offline environment.
#[cfg(test)]
mod tests {
    use jiff::Timestamp;
    use pharia_skill::{
        ChatParams, ChunkParams, ChunkRequest, CompletionParams, IndexPath, Modality,
    };

    use super::*;

    // A single completion round-trip with a Llama 3.1 prompt template.
    #[test]
    fn can_make_request() {
        // Best-effort .env loading; a missing file is fine if the variable is
        // already set in the environment.
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.complete(
            CompletionRequest::new(
                "llama-3.1-8b-instruct",
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 23 Jul 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
            )
            .with_params(CompletionParams {
                stop: vec!["<|start_header_id|>".into()],
                max_tokens: Some(10),
                ..Default::default()
            }),
        );
        assert_eq!(
            response.text.trim(),
            "The capital of France is Paris.<|eot_id|>"
        );
    }

    // The concurrent completion path must give the same answer for each of
    // the duplicated requests.
    #[test]
    fn can_make_multiple_requests() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let params = CompletionParams {
            stop: vec!["<|start_header_id|>".into()],
            max_tokens: Some(10),
            ..Default::default()
        };
        let completion_request = CompletionRequest::new(
            "llama-3.1-8b-instruct",
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 23 Jul 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
        )
        .with_params(params);

        let response = csi.complete_concurrently(vec![completion_request; 2]);
        assert!(response
            .into_iter()
            .all(|r| r.text.trim() == "The capital of France is Paris.<|eot_id|>"));
    }

    // A max-tokens-per-chunk of 1 should split "123456" into two chunks
    // (tokenized with the model's tokenizer).
    #[test]
    fn chunk() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.chunk(ChunkRequest::new(
            "123456",
            ChunkParams::new("llama-3.1-8b-instruct", 1),
        ));

        assert_eq!(response, vec!["123", "456"]);
    }

    // Language detection should pick English out of the candidate set.
    #[test]
    fn select_language() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.select_language(SelectLanguageRequest::new(
            "A rising tide lifts all boats",
            [LanguageCode::Eng, LanguageCode::Deu, LanguageCode::Fra],
        ));

        assert_eq!(response, Some(LanguageCode::Eng));
    }

    // Requires the "Kernel/test/asym-64" index to exist on the target Kernel.
    #[test]
    fn search() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.search(
            SearchRequest::new("decoder", IndexPath::new("Kernel", "test", "asym-64"))
                .with_max_results(10),
        );

        assert!(!response.is_empty());
    }

    // A minimal chat round-trip; max_tokens of 1 keeps it cheap.
    #[test]
    fn chat() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let request = ChatRequest::new(
            "llama-3.1-8b-instruct",
            Message::user("Hello, how are you?"),
        )
        .with_params(ChatParams {
            max_tokens: Some(1),
            ..Default::default()
        });
        let response = csi.chat(request);

        assert!(!response.message.content.is_empty());
    }

    // Fetches a known document and checks path, content count, and that the
    // text mentions "Kernel". The Metadata shape must match what is stored.
    #[test]
    fn documents() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Metadata {
            created: Timestamp,
            url: String,
        }

        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document::<Metadata>(path.clone()).unwrap();

        assert_eq!(response.path, path);
        assert_eq!(response.contents.len(), 1);
        assert!(
            matches!(&response.contents[0], Modality::Text { text } if text.contains("Kernel"))
        );
    }

    // Metadata-only fetch for the same document.
    #[test]
    fn document_metadata() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Metadata {
            created: Timestamp,
            url: String,
        }

        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document_metadata::<Metadata>(path.clone()).unwrap();

        assert!(response.is_some());
    }

    // Deserializing the stored metadata object into a plain String must fail,
    // surfacing as an Err rather than a panic.
    #[test]
    fn invalid_metadata() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document_metadata::<String>(path.clone());

        assert!(response.is_err());
    }
}
497}