Skip to main content

chat_core/
traits.rs

1#[cfg(feature = "stream")]
2use crate::error::ChatError;
3#[cfg(feature = "stream")]
4use crate::types::response::StreamEvent;
5use crate::{
6    error::ChatFailure,
7    types::{
8        messages::Messages,
9        options::ChatOptions,
10        provider_meta::ProviderMeta,
11        response::{ChatResponse, EmbeddingsResponse},
12    },
13};
14use async_trait::async_trait;
15#[cfg(feature = "stream")]
16use futures::stream::BoxStream;
17
18use tools_rs::ToolCollection;
19
20#[async_trait]
21pub trait CompletionProvider: Send + Sync {
22    async fn complete(
23        &mut self,
24        messages: &mut Messages,
25        tools: Option<&ToolCollection>,
26        options: Option<&ChatOptions>,
27        structured_output: Option<&schemars::Schema>,
28    ) -> Result<ChatResponse, ChatFailure>;
29
30    fn metadata(&self) -> Option<&ProviderMeta> {
31        None
32    }
33}
34
/// Interface for backends that deliver a chat response incrementally as a
/// stream of [`StreamEvent`]s.
///
/// Only compiled when the `stream` feature is enabled. Implementors must be
/// `Send + Sync`.
#[cfg(feature = "stream")]
#[async_trait]
pub trait StreamProvider: Send + Sync {
    /// Starts a streaming request.
    ///
    /// # Parameters
    ///
    /// - `messages`: the conversation to send. It is borrowed mutably, so an
    ///   implementation is permitted to modify it.
    /// - `tools`: optional tool collection made available for this call.
    /// - `options`: optional per-call options.
    ///
    /// # Returns
    ///
    /// On success, a boxed `'static` stream. Each item is itself a `Result`:
    /// either a [`StreamEvent`] or a [`ChatError`] raised while streaming.
    ///
    /// # Errors
    ///
    /// Returns a [`ChatError`] when the stream cannot be produced at all. The
    /// exact conditions are defined by each implementation.
    async fn stream(
        &mut self,
        messages: &mut Messages,
        tools: Option<&ToolCollection>,
        options: Option<&ChatOptions>,
    ) -> Result<BoxStream<'static, Result<StreamEvent, ChatError>>, ChatError>;

    /// Called after the stream has been fully consumed, with the final
    /// response.
    ///
    /// Providers can override this to store state from the completed stream.
    /// The default implementation does nothing.
    fn on_stream_done(&mut self, _response: &ChatResponse) {}
}
49
/// Combined trait for providers that support both completion and streaming.
///
/// It declares no methods of its own; it only names [`CompletionProvider`]
/// and [`StreamProvider`] as supertraits. Every provider that implements both
/// of them automatically implements this trait via the blanket impl, so it
/// never has to be implemented by hand.
///
/// Only compiled when the `stream` feature is enabled.
#[cfg(feature = "stream")]
pub trait ChatProvider: CompletionProvider + StreamProvider {}
55
/// Blanket implementation: any type that is both a [`CompletionProvider`]
/// and a [`StreamProvider`] is a [`ChatProvider`].
#[cfg(feature = "stream")]
impl<T> ChatProvider for T where T: CompletionProvider + StreamProvider {}
58
59#[async_trait]
60pub trait EmbeddingsProvider: Send + Sync {
61    async fn embed(&self, messages: &mut Messages) -> Result<EmbeddingsResponse, ChatFailure>;
62}