1use crate::error::AiError;
4use crate::types::*;
5use async_trait::async_trait;
6use futures::Stream;
7use std::fmt::Debug;
8use std::sync::Arc;
9
/// Streaming chat-completion output: each item is either the next
/// [`ChatCompletionChunk`] or an [`AiError`].
///
/// Unsized (`dyn`) on purpose — used behind a pointer, e.g.
/// `Box<ChatCompletionStream>` as returned by [`Provider::stream_chat_completion`].
pub type ChatCompletionStream =
    dyn Stream<Item = Result<ChatCompletionChunk, AiError>> + Send + Unpin;

/// Streaming plain-text output: each item is a [`TextChunk`] or an [`AiError`].
pub type TextStream = dyn Stream<Item = Result<TextChunk, AiError>> + Send + Unpin;

/// Streaming structured (JSON) output: each item is a raw `serde_json::Value`
/// or an [`AiError`].
pub type ObjectStream = dyn Stream<Item = Result<serde_json::Value, AiError>> + Send + Unpin;
/// A chat-completion backend (e.g. one vendor's model API).
///
/// Implementations must be `Send + Sync + 'static`, so a provider can be
/// shared across tasks/threads (typically behind an `Arc`) for its whole
/// lifetime, and `Debug` for diagnostics.
#[async_trait]
pub trait Provider: Send + Sync + Debug + 'static {
    /// Returns metadata describing this provider.
    ///
    /// Returned as `Arc<ProviderInfo>` so callers can hold onto it cheaply
    /// without cloning the underlying data.
    fn info(&self) -> Arc<ProviderInfo>;

    /// Performs a single, non-streaming chat completion and returns the
    /// full response, or an [`AiError`] on failure.
    async fn chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiError>;

    /// Starts a streaming chat completion.
    ///
    /// On success, returns a boxed stream that yields incremental
    /// [`ChatCompletionChunk`]s; each stream item may itself be an error.
    async fn stream_chat_completion(
        &self,
        req: ChatCompletionRequest,
    ) -> Result<Box<ChatCompletionStream>, AiError>;
}
47
48pub async fn collect_text_stream(
50 response: TextResponse,
51 mut stream: Box<TextStream>,
52) -> Result<TextResult, AiError> {
53 use futures::StreamExt;
54
55 let mut content = String::new();
56 let mut finish_reason = None;
57 let mut usage = None;
58 let tool_calls = None;
59
60 while let Some(chunk) = stream.next().await {
61 let chunk = chunk?;
62 content.push_str(&chunk.delta);
63
64 if let Some(reason) = chunk.finish_reason {
65 finish_reason = Some(reason);
66 }
67
68 if let Some(u) = chunk.usage {
69 usage = Some(u);
70 }
71 }
72
73 Ok(TextResult {
74 content,
75 finish_reason: finish_reason.unwrap_or(FinishReason::Stop),
76 usage: usage.unwrap_or(Usage {
77 prompt_tokens: 0,
78 completion_tokens: 0,
79 total_tokens: 0,
80 }),
81 model: response.model,
82 tool_calls,
83 })
84}