use std::time::Duration;

use pharia_skill::{
    ChatRequest, ChatResponse, ChunkRequest, Completion, CompletionRequest, Csi, Document,
    DocumentPath, FinishReason, LanguageCode, Message, SearchRequest, SearchResult,
    SelectLanguageRequest, TokenUsage,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use ureq::{json, serde_json::Value, Agent, AgentBuilder};

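/// A [`Csi`] implementation that performs no I/O: completions echo the prompt,
/// chunking returns the input text as a single chunk, and all other functions
/// return empty or `None` values. Useful when a test only needs the trait to
/// be satisfied.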
pub struct StubCsi;

impl Csi for StubCsi {
    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
        requests
            .iter()
            .map(|_| ChatResponse {
                message: Message::new("user", ""),
                finish_reason: FinishReason::Stop,
                logprobs: vec![],
                usage: TokenUsage {
                    prompt: 0,
                    completion: 0,
                },
            })
            .collect()
    }

    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
        requests
            .into_iter()
            .map(|request| Completion {
                text: request.prompt,
                finish_reason: FinishReason::Stop,
                logprobs: vec![],
                usage: TokenUsage {
                    prompt: 0,
                    completion: 0,
                },
            })
            .collect()
    }

    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
        requests
            .into_iter()
            .map(|request| vec![request.text])
            .collect()
    }

    fn select_language_concurrently(
        &self,
        requests: Vec<SelectLanguageRequest>,
    ) -> Vec<Option<LanguageCode>> {
        requests.iter().map(|_| None).collect()
    }

    fn search_concurrently(&self, _requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
        vec![]
    }

    fn documents<Metadata>(
        &self,
        _paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Document<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a>,
    {
        Ok(vec![])
    }

    fn documents_metadata<Metadata>(
        &self,
        _paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Option<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a>,
    {
        Ok(vec![])
    }
}

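/// A [`Csi`] test double that answers every chat and completion request with
/// the same canned response.
///
/// A minimal usage sketch (`my_skill` is a hypothetical function under test):
///
/// ```ignore
/// let csi = MockCsi::new("Paris");
/// let answer = my_skill(&csi, "What is the capital of France?");
/// ```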
pub struct MockCsi {
    response: String,
}

impl MockCsi {
    #[must_use]
    pub fn new(response: impl Into<String>) -> Self {
        Self {
            response: response.into(),
        }
    }
}

impl Csi for MockCsi {
    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
        requests
            .iter()
            .map(|_| ChatResponse {
                message: Message::new("user", self.response.clone()),
                finish_reason: FinishReason::Stop,
                logprobs: vec![],
                usage: TokenUsage {
                    prompt: 0,
                    completion: 0,
                },
            })
            .collect()
    }

    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
        requests
            .iter()
            .map(|_| Completion {
                text: self.response.clone(),
                finish_reason: FinishReason::Stop,
                logprobs: vec![],
                usage: TokenUsage {
                    prompt: 0,
                    completion: 0,
                },
            })
            .collect()
    }

    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
        requests
            .into_iter()
            .map(|request| vec![request.text])
            .collect()
    }

    fn select_language_concurrently(
        &self,
        requests: Vec<SelectLanguageRequest>,
    ) -> Vec<Option<LanguageCode>> {
        requests.iter().map(|_| None).collect()
    }

    fn search_concurrently(&self, _requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
        vec![]
    }

    fn documents<Metadata>(
        &self,
        _paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Document<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a>,
    {
        Ok(vec![])
    }

    fn documents_metadata<Metadata>(
        &self,
        _paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Option<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a>,
    {
        Ok(vec![])
    }
}

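/// The functions exposed by the Kernel's `/csi` endpoint, serialized in
/// `snake_case` to match the wire format.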
#[derive(Copy, Clone, Debug, Serialize)]
#[serde(rename_all = "snake_case")]
enum Function {
    Complete,
    Chunk,
    SelectLanguage,
    Search,
    Chat,
    Documents,
    DocumentMetadata,
}

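/// Versioned envelope for a call to the `/csi` endpoint; the payload's fields
/// are flattened into the same JSON object as `version` and `function`.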
#[derive(Serialize)]
struct CsiRequest<'a, P: Serialize> {
    version: &'a str,
    function: Function,
    #[serde(flatten)]
    payload: P,
}

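/// A [`Csi`] implementation backed by a running Pharia Kernel, so that skill
/// code can be exercised against real inference during development.
///
/// A minimal sketch; the address is a hypothetical local deployment, and the
/// token must be accepted by that Kernel:
///
/// ```ignore
/// let csi = DevCsi::new("http://localhost:8081", "my-token");
/// ```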
pub struct DevCsi {
    address: String,
    agent: Agent,
    token: String,
}

impl DevCsi {
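    /// The version of the CSI wire format this client speaks.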
    const VERSION: &'static str = "0.3";

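    /// Constructs a client for the Kernel at `address`, authenticating each
    /// request with `token`. Requests time out after five minutes.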
    #[must_use]
    pub fn new(address: impl Into<String>, token: impl Into<String>) -> Self {
        let agent = AgentBuilder::new()
            .timeout(Duration::from_secs(60 * 5))
            .build();
        Self {
            address: address.into(),
            agent,
            token: token.into(),
        }
    }

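    /// Convenience constructor for the Kernel hosted by Aleph Alpha.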
    pub fn aleph_alpha(token: impl Into<String>) -> Self {
        Self::new("https://pharia-kernel.product.pharia.com", token)
    }

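    /// Wraps `payload` in a versioned [`CsiRequest`] envelope, posts it to the
    /// Kernel's `/csi` endpoint, and deserializes the response into `R`.
    /// HTTP and transport failures are returned as errors rather than panics.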
    fn csi_request<R: DeserializeOwned>(
        &self,
        function: Function,
        payload: impl Serialize,
    ) -> anyhow::Result<R> {
        let json = CsiRequest {
            version: Self::VERSION,
            function,
            payload,
        };
        let response = self
            .agent
            .post(&format!("{}/csi", &self.address))
            .set("Authorization", &format!("Bearer {}", self.token))
            .send_json(json);

        match response {
            Ok(response) => Ok(response.into_json::<R>()?),
            Err(ureq::Error::Status(status, response)) => {
                let body = response.into_json::<Value>().unwrap_or_default();
                anyhow::bail!("Failed Request: Status {status} {body}")
            }
            Err(e) => Err(e.into()),
        }
    }
}

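// Each trait method forwards its batch of requests verbatim to the
// corresponding CSI function on the Kernel.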
impl Csi for DevCsi {
    fn chat_concurrently(&self, requests: Vec<ChatRequest>) -> Vec<ChatResponse> {
        self.csi_request(Function::Chat, json!({"requests": requests}))
            .unwrap()
    }

    fn complete_concurrently(&self, requests: Vec<CompletionRequest>) -> Vec<Completion> {
        self.csi_request(Function::Complete, json!({"requests": requests}))
            .unwrap()
    }

    fn chunk_concurrently(&self, requests: Vec<ChunkRequest>) -> Vec<Vec<String>> {
        self.csi_request(Function::Chunk, json!({"requests": requests}))
            .unwrap()
    }

    fn select_language_concurrently(
        &self,
        requests: Vec<SelectLanguageRequest>,
    ) -> Vec<Option<LanguageCode>> {
        self.csi_request(Function::SelectLanguage, json!({"requests": requests}))
            .unwrap()
    }

    fn search_concurrently(&self, requests: Vec<SearchRequest>) -> Vec<Vec<SearchResult>> {
        self.csi_request(Function::Search, json!({"requests": requests}))
            .unwrap()
    }

    fn documents<Metadata>(
        &self,
        paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Document<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a> + Serialize,
    {
        self.csi_request(Function::Documents, json!({"requests": paths}))
    }

    fn documents_metadata<Metadata>(
        &self,
        paths: Vec<DocumentPath>,
    ) -> anyhow::Result<Vec<Option<Metadata>>>
    where
        Metadata: for<'a> Deserialize<'a> + Serialize,
    {
        self.csi_request(Function::DocumentMetadata, json!({"requests": paths}))
    }
}

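// These tests run against a live Kernel and require a valid `PHARIA_AI_TOKEN`
// in the environment (or in a `.env` file picked up by `dotenvy`).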
#[cfg(test)]
mod tests {
    use jiff::Timestamp;
    use pharia_skill::{
        ChatParams, ChunkParams, ChunkRequest, CompletionParams, IndexPath, Modality,
    };

    use super::*;

    #[test]
    fn can_make_request() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.complete(
            CompletionRequest::new(
                "llama-3.1-8b-instruct",
                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 23 Jul 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
            )
            .with_params(CompletionParams {
                stop: vec!["<|start_header_id|>".into()],
                max_tokens: Some(10),
                ..Default::default()
            }),
        );
        assert_eq!(
            response.text.trim(),
            "The capital of France is Paris.<|eot_id|>"
        );
    }

    #[test]
    fn can_make_multiple_requests() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let params = CompletionParams {
            stop: vec!["<|start_header_id|>".into()],
            max_tokens: Some(10),
            ..Default::default()
        };
        let completion_request = CompletionRequest::new(
            "llama-3.1-8b-instruct",
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 23 Jul 2024

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>",
        )
        .with_params(params);

        let response = csi.complete_concurrently(vec![completion_request; 2]);
        assert!(response
            .into_iter()
            .all(|r| r.text.trim() == "The capital of France is Paris.<|eot_id|>"));
    }

    #[test]
    fn chunk() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.chunk(ChunkRequest::new(
            "123456",
            ChunkParams::new("llama-3.1-8b-instruct", 1),
        ));

        assert_eq!(response, vec!["123", "456"]);
    }

    #[test]
    fn select_language() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.select_language(SelectLanguageRequest::new(
            "A rising tide lifts all boats",
            [LanguageCode::Eng, LanguageCode::Deu, LanguageCode::Fra],
        ));

        assert_eq!(response, Some(LanguageCode::Eng));
    }

    #[test]
    fn search() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let response = csi.search(
            SearchRequest::new("decoder", IndexPath::new("Kernel", "test", "asym-64"))
                .with_max_results(10),
        );

        assert!(!response.is_empty());
    }

    #[test]
    fn chat() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let request = ChatRequest::new(
            "llama-3.1-8b-instruct",
            Message::user("Hello, how are you?"),
        )
        .with_params(ChatParams {
            max_tokens: Some(1),
            ..Default::default()
        });
        let response = csi.chat(request);

        assert!(!response.message.content.is_empty());
    }

    #[test]
    fn documents() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Metadata {
            created: Timestamp,
            url: String,
        }

        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document::<Metadata>(path.clone()).unwrap();

        assert_eq!(response.path, path);
        assert_eq!(response.contents.len(), 1);
        assert!(
            matches!(&response.contents[0], Modality::Text { text } if text.contains("Kernel"))
        );
    }

    #[test]
    fn document_metadata() {
        #[derive(Debug, Deserialize, Serialize)]
        struct Metadata {
            created: Timestamp,
            url: String,
        }

        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document_metadata::<Metadata>(path).unwrap();

        assert!(response.is_some());
    }

    #[test]
    fn invalid_metadata() {
        drop(dotenvy::dotenv());

        let token = std::env::var("PHARIA_AI_TOKEN").unwrap();
        let csi = DevCsi::aleph_alpha(token);

        let path = DocumentPath::new("Kernel", "test", "kernel-docs");
        let response = csi.document_metadata::<String>(path);

        assert!(response.is_err());
    }
}
497}