rainy_sdk/endpoints/chat.rs
1use crate::client::RainyClient;
2use crate::error::{RainyError, Result};
3use crate::models::{ChatCompletionRequest, ChatCompletionResponse, ChatCompletionStreamResponse};
4use futures::Stream;
5use std::pin::Pin;
6
7impl RainyClient {
8 /// Create a chat completion
9 ///
10 /// This endpoint sends a chat completion request to the Rainy API.
11 ///
12 /// # Arguments
13 ///
14 /// * `request` - The chat completion request parameters
15 ///
16 /// # Returns
17 ///
18 /// Returns the chat completion response from the AI model.
19 ///
20 /// # Example
21 ///
22 /// ```rust,no_run
23 /// # use rainy_sdk::{RainyClient, ChatCompletionRequest, ChatMessage, MessageRole};
24 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
25 /// let client = RainyClient::with_api_key("user-api-key")?;
26 ///
27 /// let messages = vec![
28 /// ChatMessage::user("Hello, how are you?"),
29 /// ];
30 ///
31 /// let request = ChatCompletionRequest::new("gemini-pro", messages)
32 /// .with_max_tokens(150)
33 /// .with_temperature(0.7);
34 ///
35 /// let response = client.create_chat_completion(request).await?;
36 ///
37 /// if let Some(choice) = response.choices.first() {
38 /// println!("Response: {}", choice.message.content);
39 /// }
40 /// # Ok(())
41 /// # }
42 /// ```
43 pub async fn create_chat_completion(
44 &self,
45 request: ChatCompletionRequest,
46 ) -> Result<ChatCompletionResponse> {
47 let body = serde_json::to_value(request)?;
48 self.make_request(reqwest::Method::POST, "/chat/completions", Some(body))
49 .await
50 }
51
52 /// Create a chat completion with streaming
53 ///
54 /// This method provides streaming support for chat completions.
55 ///
56 /// # Arguments
57 ///
58 /// * `request` - The chat completion request parameters
59 ///
60 /// # Returns
61 ///
62 /// Returns a stream of chat completion responses.
63 ///
64 /// # Example
65 ///
66 /// ```rust,no_run
67 /// # use rainy_sdk::{RainyClient, ChatCompletionRequest, ChatMessage, MessageRole};
68 /// # use futures::StreamExt;
69 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
70 /// let client = RainyClient::with_api_key("user-api-key")?;
71 ///
72 /// let messages = vec![
73 /// ChatMessage::user("Tell me a story"),
74 /// ];
75 ///
76 /// let request = ChatCompletionRequest::new("llama-3.1-8b-instant", messages)
77 /// .with_max_tokens(500)
78 /// .with_temperature(0.8)
79 /// .with_stream(true);
80 ///
81 /// let mut stream = client.create_chat_completion_stream(request).await?;
82 ///
83 /// while let Some(chunk) = stream.next().await {
84 /// match chunk {
85 /// Ok(response) => {
86 /// if let Some(choice) = response.choices.first() {
87 /// if let Some(content) = &choice.delta.content {
88 /// print!("{}", content);
89 /// }
90 /// }
91 /// }
92 /// Err(e) => eprintln!("Error: {}", e),
93 /// }
94 /// }
95 /// # Ok(())
96 /// # }
97 /// ```
98 pub async fn create_chat_completion_stream(
99 &self,
100 request: ChatCompletionRequest,
101 ) -> Result<Pin<Box<dyn Stream<Item = Result<ChatCompletionStreamResponse>> + Send>>> {
102 use eventsource_stream::Eventsource;
103 use futures::StreamExt;
104
105 let mut request_with_stream = request;
106 request_with_stream.stream = Some(true);
107
108 let url = format!("{}/api/v1/chat/completions", self.auth_config().base_url);
109 let headers = self.auth_config().build_headers()?;
110
111 let response = self
112 .http_client()
113 .post(&url)
114 .headers(headers)
115 .json(&request_with_stream)
116 .send()
117 .await?;
118
119 if !response.status().is_success() {
120 return Err(self
121 .handle_response::<ChatCompletionStreamResponse>(response)
122 .await
123 .err()
124 .unwrap());
125 }
126
127 let stream = response
128 .bytes_stream()
129 .eventsource()
130 .filter_map(|event| async move {
131 match event {
132 Ok(event) => {
133 // Handle the [DONE] marker
134 if event.data.trim() == "[DONE]" {
135 return None;
136 }
137
138 // Parse the JSON data
139 match serde_json::from_str::<ChatCompletionStreamResponse>(&event.data) {
140 Ok(response) => Some(Ok(response)),
141 Err(e) => Some(Err(RainyError::Serialization {
142 message: e.to_string(),
143 source_error: Some(e.to_string()),
144 })),
145 }
146 }
147 Err(e) => {
148 // Convert eventsource error to RainyError
149 Some(Err(RainyError::Network {
150 message: format!("SSE parsing error: {e}"),
151 retryable: true,
152 source_error: Some(e.to_string()),
153 }))
154 }
155 }
156 });
157
158 Ok(Box::pin(stream))
159 }
160}