rag_toolchain/clients/
traits.rs1use crate::common::{Chunk, Chunks, Embedding};
2use std::error::Error;
3use std::future::Future;
4
5use super::types::PromptMessage;
6
7pub trait AsyncEmbeddingClient {
10 type ErrorType: Error;
11 fn generate_embedding(
12 &self,
13 text: Chunk,
14 ) -> impl Future<Output = Result<Embedding, Self::ErrorType>> + Send;
15 fn generate_embeddings(
16 &self,
17 text: Chunks,
18 ) -> impl Future<Output = Result<Vec<Embedding>, Self::ErrorType>> + Send;
19}
20
21pub trait AsyncChatClient {
24 type ErrorType: Error;
25 fn invoke(
26 &self,
27 prompt_messages: Vec<PromptMessage>,
28 ) -> impl Future<Output = Result<PromptMessage, Self::ErrorType>> + Send;
29}
30
31pub trait AsyncStreamedChatClient {
34 type ErrorType: Error;
35 type Item: ChatCompletionStream;
36 fn invoke_stream(
37 &self,
38 prompt_messages: Vec<PromptMessage>,
39 ) -> impl Future<Output = Result<Self::Item, Self::ErrorType>> + Send;
40}
41
/// A pull-based asynchronous stream of chat completion items.
///
/// Each call to [`ChatCompletionStream::next`] resolves to `Some(Ok(item))`,
/// `Some(Err(error))`, or `None`.
// NOTE(review): `None` presumably signals that the stream is exhausted,
// mirroring `Iterator::next` - confirm against the implementors.
pub trait ChatCompletionStream {
    /// Error type yielded when retrieving an item fails.
    type ErrorType: Error;

    /// Type of the values produced by the stream.
    type Item;

    /// Advances the stream and resolves to its next element, if any.
    // NOTE(review): unlike the client traits in this file, the returned future
    // carries no `Send` bound. Adding one would break implementors whose
    // futures are not `Send`, so the signature is deliberately left as-is.
    fn next(&mut self) -> impl Future<Output = Option<Result<Self::Item, Self::ErrorType>>>;
}
50
51#[cfg(test)]
52use mockall::*;
53
// Test-only mock of `AsyncChatClient`, generated by mockall as
// `MockAsyncChatClient`. `std::io::Error` stands in for the client error type.
#[cfg(test)]
mock! {
    pub AsyncChatClient {}
    impl AsyncChatClient for AsyncChatClient {
        type ErrorType = std::io::Error;
        async fn invoke(
            &self,
            prompt_messages: Vec<PromptMessage>,
        ) -> Result<PromptMessage, <Self as AsyncChatClient>::ErrorType>;
    }
}
65
// Test-only mock of `AsyncStreamedChatClient`, generated by mockall as
// `MockAsyncStreamedChatClient`. On success it hands back a
// `MockChatCompletionStream`; `std::io::Error` stands in for the error type.
#[cfg(test)]
mock! {
    pub AsyncStreamedChatClient {}
    impl AsyncStreamedChatClient for AsyncStreamedChatClient {
        type ErrorType = std::io::Error;
        type Item = MockChatCompletionStream;
        async fn invoke_stream(
            &self,
            prompt_messages: Vec<PromptMessage>,
        ) -> Result<MockChatCompletionStream, <Self as AsyncStreamedChatClient>::ErrorType>;
    }
}
79
// Test-only mock of `ChatCompletionStream`, generated by mockall as
// `MockChatCompletionStream`. It yields `PromptMessage` items and uses
// `std::io::Error` as its error type.
#[cfg(test)]
mock! {
    pub ChatCompletionStream {}
    impl ChatCompletionStream for ChatCompletionStream {
        type ErrorType = std::io::Error;
        type Item = PromptMessage;
        async fn next(&mut self) -> Option<Result<PromptMessage, <Self as ChatCompletionStream>::ErrorType>>;
    }
}