// rag_toolchain/clients/traits.rs

1use crate::common::{Chunk, Chunks, Embedding};
2use std::error::Error;
3use std::future::Future;
4
5use super::types::PromptMessage;
6
7/// # [`AsyncEmbeddingClient`]
8/// Trait for any client that generates embeddings asynchronously
9pub trait AsyncEmbeddingClient {
10    type ErrorType: Error;
11    fn generate_embedding(
12        &self,
13        text: Chunk,
14    ) -> impl Future<Output = Result<Embedding, Self::ErrorType>> + Send;
15    fn generate_embeddings(
16        &self,
17        text: Chunks,
18    ) -> impl Future<Output = Result<Vec<Embedding>, Self::ErrorType>> + Send;
19}
20
21/// # [`AsyncChatClient`]
22/// Trait for any client that generates chat completions asynchronously
23pub trait AsyncChatClient {
24    type ErrorType: Error;
25    fn invoke(
26        &self,
27        prompt_messages: Vec<PromptMessage>,
28    ) -> impl Future<Output = Result<PromptMessage, Self::ErrorType>> + Send;
29}
30
31/// # [`AsyncStreamedChatClient`]
32/// Trait for any client that generates streamed chat completions asynchronously
33pub trait AsyncStreamedChatClient {
34    type ErrorType: Error;
35    type Item: ChatCompletionStream;
36    fn invoke_stream(
37        &self,
38        prompt_messages: Vec<PromptMessage>,
39    ) -> impl Future<Output = Result<Self::Item, Self::ErrorType>> + Send;
40}
41
/// # [`ChatCompletionStream`]
///
/// Contract for an asynchronous source of chat-completion items, pulled one at
/// a time with [`ChatCompletionStream::next`].
pub trait ChatCompletionStream {
    /// Error type carried by a failed pull.
    type ErrorType: Error;

    /// Type of each successfully produced item.
    type Item;

    /// Pulls the next element from the stream.
    ///
    /// Resolves to `Some(Ok(item))` for a produced item, `Some(Err(e))` when the
    /// stream reports an error, and presumably `None` once the stream is
    /// exhausted — confirm with implementations.
    ///
    /// NOTE(review): unlike the client traits in this module, the returned
    /// future is not bounded by `Send`, so generic callers cannot move it across
    /// threads (e.g. into a spawned task). Tightening this would break any
    /// implementor whose future is `!Send`.
    fn next(&mut self) -> impl Future<Output = Option<Result<Self::Item, Self::ErrorType>>>;
}
50
51#[cfg(test)]
52use mockall::*;
53
// Test double for `AsyncChatClient`; the `mock!` expansion is named with a
// `Mock` prefix (`MockAsyncChatClient`). Its error type is pinned to
// `std::io::Error`, so the `invoke` signature spells that type out directly.
#[cfg(test)]
mock! {
    pub AsyncChatClient {}
    impl AsyncChatClient for AsyncChatClient {
        type ErrorType = std::io::Error;
        async fn invoke(
            &self,
            messages: Vec<PromptMessage>,
        ) -> Result<PromptMessage, std::io::Error>;
    }
}
65
// Test double for `AsyncStreamedChatClient`, expanded as
// `MockAsyncStreamedChatClient`. Its `invoke_stream` hands back a
// `MockChatCompletionStream`, so the two mocks are used together.
//
// No `#[derive(...)]` on the mock: mockall discards derive attributes on
// generated mocks, and a mock carrying expectation state could not be `Copy`
// anyway. The previous `#[derive(Copy)]` had no effect and only implied that
// the mock was copyable.
#[cfg(test)]
mock! {
    pub AsyncStreamedChatClient {}
    impl AsyncStreamedChatClient for AsyncStreamedChatClient {
        type ErrorType = std::io::Error;
        type Item = MockChatCompletionStream;
        async fn invoke_stream(
            &self,
            prompt_messages: Vec<PromptMessage>,
        ) -> Result<MockChatCompletionStream, std::io::Error>;
    }
}
79
// Test double for `ChatCompletionStream`, expanded as
// `MockChatCompletionStream`; each `next` call yields a `PromptMessage` item
// or an `std::io::Error`, as configured on the expectation.
#[cfg(test)]
mock! {
    pub ChatCompletionStream {}
    impl ChatCompletionStream for ChatCompletionStream {
        type ErrorType = std::io::Error;
        type Item = PromptMessage;
        async fn next(&mut self) -> Option<Result<PromptMessage, std::io::Error>>;
    }
}