Skip to main content

my_chatgpt/
response.rs

1use reqwest::Client;
2use reqwest::header::{AUTHORIZATION, CONTENT_TYPE};
3use serde::{Deserialize, Serialize};
4use futures_util::{TryStreamExt};
5use std::sync::{Arc, Mutex};
6
7#[derive(Serialize, Deserialize, Debug, Clone)]
8pub struct Message {
9    pub role: String,
10    pub content: String,
11}
12
13#[derive(Serialize, Debug)]
14struct Request {
15    model: String,
16    input: String,
17    instructions: String,
18    stream: bool,
19    tools: Vec<Tool>,
20    previous_response_id: Option<String>,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24pub struct Tool {
25    #[serde(rename = "type")]
26    pub tool_type: String,
27}
28
29#[derive(Deserialize, Debug, Clone, Serialize)]
30pub struct UsageInfo {
31    pub input_tokens: Option<u32>,
32    pub input_tokens_details: Option<InputTokensDetails>,
33    pub output_tokens: Option<u32>,
34    pub output_tokens_details: Option<OutputTokensDetails>,
35    pub total_tokens: Option<u32>,
36}
37
38#[derive(Deserialize, Debug, Clone, Serialize)]
39pub struct InputTokensDetails {
40    pub cached_tokens: Option<u32>,
41}
42
43#[derive(Deserialize, Debug, Clone, Serialize)]
44pub struct OutputTokensDetails {
45    pub reasoning_tokens: Option<u32>,
46}
47
48// Define a general error type that encapsulates all possible error cases
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub enum ResponseError {
51    RequestError(String),
52    ParseError(String),
53    NetworkError(String),
54    Unknown(String),
55}
56#[derive(Debug, Clone,  Deserialize, Serialize)]
57pub enum SendChatResult {
58    Ok(SendChatOK),  // (message, id)
59    Err(ResponseError),
60}
61
62#[derive(Debug, Clone, Deserialize, Serialize)]
63pub struct SendChatOK {
64    pub model: String,
65    pub message: String,
66    pub id: String,
67    pub usage: UsageInfo,
68    pub tools: Vec<Tool>,
69}
70
71pub async fn send_chat<F>(
72    instructions: &str, 
73    input: &str, 
74    api_key: &str, 
75    model: &str, 
76    handler: &F,
77    previous_response_id: Option<&str>, // if this is provided, the message will be a continuation of the previous message
78) -> SendChatResult
79where 
80    F: Fn(Option<&UsageInfo>, Option<&ResponseError>, Option<&str>) + Send + Sync + 'static
81{
82    let client: Client = Client::new();
83    let mut tools = vec![];
84    match model {
85        "gpt-4.1-nano" => {
86            tools = vec![];
87        },
88        "gpt-4.1-mini" => {
89            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
90        },
91        "gpt-4.1" => {
92            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
93        },
94        "gpt-4o-mini" => {
95            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
96        },
97        "gpt-4o-mini-search-preview" => {
98            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
99        },
100        "gpt-4o" => {
101            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
102        },
103        "gpt-4o-search-preview" => {
104            tools = vec![Tool { tool_type: "web_search_preview".to_string() }];
105        }
106        _ => {}
107    }
108
109    
110
111    // Initialize updated history
112    let body = Request {
113        model: model.to_string(),
114        instructions: instructions.to_string(),
115        input: input.to_string(),
116        stream:true,
117        tools: tools,
118        previous_response_id: match previous_response_id {
119            None => None,
120            Some(id) => Some(id.to_string())
121        },
122    };
123
124    println!("***body: {:?}", body);
125
126    let res = match client
127        .post("https://api.openai.com/v1/responses")
128        .header(AUTHORIZATION, format!("Bearer {}", api_key))
129        .header(CONTENT_TYPE, "application/json")
130        .json(&body)
131        .send()
132        .await {
133            Ok(res) => {
134                println!("***Response status: {}", res.status());
135                res
136            },
137            Err(e) => {
138                let err = ResponseError::NetworkError(e.to_string());
139                handler(None, Some(&err), None);
140                return SendChatResult::Err(err);
141            }
142        };
143
144    if !res.status().is_success() {
145        let error_text = match res.text().await {
146            Ok(text) => text,
147            Err(e) => format!("***Failed to get error response from openai: {}", e),
148        };
149        println!("***Error response: {}", error_text);
150        return SendChatResult::Err(ResponseError::NetworkError(format!("API request failed: {}", error_text)));
151    }
152
153    let final_response = Arc::new(Mutex::new(SendChatOK { 
154        message: String::new(), 
155        id: String::new(), 
156        model: model.to_string(), 
157        usage: UsageInfo { 
158            input_tokens: None, 
159            input_tokens_details: None, 
160            output_tokens: None, 
161            output_tokens_details: None, 
162            total_tokens: None 
163        }, 
164        tools: Vec::new() 
165    }));
166
167    let final_response_for_stream = Arc::clone(&final_response);
168    let stream = res.bytes_stream()
169        .map_err(|e| ResponseError::NetworkError(e.to_string()))
170        .try_filter_map(|chunk| async move {
171            // Convert to String to avoid lifetime issues
172            let text = std::str::from_utf8(&chunk)
173                .map_err(|e| ResponseError::ParseError(format!("Failed to parse UTF-8: {}", e)))?;
174            
175            // Filter and collect the lines that start with "data: "
176            let lines: Vec<String> = text.lines()
177                .filter(|line| line.starts_with("data: "))
178                .map(|line| line[6..].to_string())
179                .collect();
180
181            Ok(if lines.is_empty() { None } else { Some(lines) })
182        })
183        .try_for_each(move |lines| {
184            let final_response = Arc::clone(&final_response_for_stream);
185            async move {
186                for payload in lines {
187                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&payload) {
188                        match value.get("type").and_then(|t| t.as_str()) {
189                            Some("response.in_progress") => {
190                                println!("\n=== Response In Progress Event ===");
191                                if let Ok(pretty_json) = serde_json::to_string_pretty(&value) {
192                                    println!("{}", pretty_json);
193                                } else {
194                                    println!("Raw value: {:?}", value);
195                                }
196                                let id = value.get("response").and_then(|r| r.get("id")).and_then(|id| id.as_str()).unwrap_or("");
197                                if let Ok(mut response) = final_response.lock() {
198                                    response.id = id.to_string(); // in case the id is not provided in the response.completed event
199                                }
200                            },
201                            Some("response.output_text.delta") => {
202                                println!("\n=== Response Content delta Event ===");
203                                // Use if-let-chains to avoid nested if-lets
204                                if let Some(delta) = value.get("delta").and_then(|d| d.as_str()) {
205                                    println!("***delta: {:?}", delta);    
206                                    handler(None, None, Some(delta));
207                                    if let Err(e) = std::io::Write::flush(&mut std::io::stdout()) {
208                                        let err = ResponseError::Unknown(format!("Failed to flush stdout: {}", e));
209                                        handler(None, Some(&err), None);
210                                    }
211                                }
212                            },
213                            Some("response.output_text.done") => {
214                                println!("\n=== Response Output Text Done Event ===");
215                                if let Ok(pretty_json) = serde_json::to_string_pretty(&value) {
216                                    println!("{}", pretty_json);
217                                } else {
218                                    println!("Raw value: {:?}", value);
219                                }
220                            },
221                            Some("response.content_part.done") => {
222                                println!("\n=== Response Content Part Done Event ===");
223                                if let Ok(pretty_json) = serde_json::to_string_pretty(&value) {
224                                    println!("{}", pretty_json);
225                                } else {
226                                    println!("Raw value: {:?}", value);
227                                }
228                                // Update message directly from the part if available
229                                if let Some(message) = value.get("part")
230                                    .and_then(|part| part.get("text"))
231                                    .and_then(|text| text.as_str()) {
232                                    if let Ok(mut response) = final_response.lock() {
233                                        response.message = message.to_string();
234                                    }
235                                }
236                                println!("===============================\n");
237                            },
238                            Some("response.output_item.done") => {
239                                println!("\n=== Response Output Item Done Event ===");
240                                if let Ok(pretty_json) = serde_json::to_string_pretty(&value) {
241                                    println!("{}", pretty_json);
242                                } else {
243                                    println!("Raw value: {:?}", value);
244                                }
245                                println!("===============================\n");
246                            },
247                            Some("response.completed") => {
248                                println!("\n=== Response Completed Event ===");
249                                if let Ok(pretty_json) = serde_json::to_string_pretty(&value) {
250                                    println!("{}", pretty_json);
251                                } else {
252                                    println!("Raw value: {:?}", value);
253                                }
254                                
255                                if let Some(response_value) = value.get("response") {
256                                    if let Ok(mut response) = final_response.lock() {
257                                        // Update fields directly from response_value to avoid unnecessary allocations
258                                        if let Some(id) = response_value.get("id").and_then(|id| id.as_str()) {
259                                            response.id = id.to_string();
260                                        }
261                                        // Deserialize directly into the existing struct to avoid allocation
262                                        if let Some(usage) = response_value.get("usage") {
263                                            if let Ok(u) = serde_json::from_value(usage.clone()) {
264                                                response.usage = u;
265                                            }
266                                        }
267                                        if let Some(tools) = response_value.get("tools") {
268                                            if let Ok(t) = serde_json::from_value(tools.clone()) {
269                                                response.tools = t;
270                                            }
271                                        }
272                                        println!("final_response: {:?}", response);
273                                        println!("===============================\n");
274                                    }
275                                }
276                                
277                                // Return from the try_for_each closure to end streaming
278                                return Ok(());
279                            },
280                            _ => {
281                                println!("value is {:?}", value);
282                            }
283                        }
284                    }
285                }
286                Ok(())
287            }
288        })
289        .await;
290
291    match stream {
292        Ok(_) => {
293            // Use try_unwrap to avoid unnecessary clone
294            match Arc::try_unwrap(final_response) {
295                Ok(mutex) => match mutex.into_inner() {
296                    Ok(response) => SendChatResult::Ok(response),
297                    Err(_) => SendChatResult::Err(ResponseError::Unknown("Failed to extract final response from mutex".to_string()))
298                },
299                Err(_) => SendChatResult::Err(ResponseError::Unknown("Failed to extract final response from Arc".to_string()))
300            }
301        },
302        Err(e) => SendChatResult::Err(e)
303    }
304}