llm_link/llm/stream.rs

use super::Client;
use anyhow::{anyhow, Result};
use llm_connector::{types::ChatRequest, StreamFormat};
use tokio_stream::wrappers::UnboundedReceiverStream;

impl Client {
    /// Send a streaming chat request with specified format (Ollama-style response)
    ///
    /// This method returns streaming responses in Ollama API format, which is used by
    /// Ollama-compatible clients like Zed.dev.
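    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest).
    /// The model name is illustrative and `forward_line` is a hypothetical consumer
    /// of the already-formatted chunk strings:
    ///
    /// ```ignore
    /// use futures_util::StreamExt;
    /// use llm_connector::StreamFormat;
    ///
    /// async fn stream_to_client(
    ///     client: &Client,
    ///     messages: Vec<llm_connector::types::Message>,
    /// ) -> anyhow::Result<()> {
    ///     // Each item is already a formatted Ollama-style NDJSON line.
    ///     let mut stream = client
    ///         .chat_stream_with_format("llama3", messages, StreamFormat::NDJSON)
    ///         .await?;
    ///     while let Some(line) = stream.next().await {
    ///         forward_line(line).await; // hypothetical: write the line to the HTTP response
    ///     }
    ///     Ok(())
    /// }
    /// ```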
    pub async fn chat_stream_with_format(
        &self,
        model: &str,
        messages: Vec<llm_connector::types::Message>,
        format: StreamFormat,
    ) -> Result<UnboundedReceiverStream<String>> {
        use futures_util::StreamExt;

        // Messages are already in llm-connector format
        let request = ChatRequest {
            model: model.to_string(),
            messages,
            stream: Some(true),
            ..Default::default()
        };

        tracing::info!("🔄 Requesting streaming from LLM connector...");

        // Use real streaming API
        let mut stream = self.llm_client.chat_stream(&request).await
            .map_err(|e| anyhow!("LLM connector streaming error: {}", e))?;

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let model_name = model.to_string();

        tokio::spawn(async move {
            tracing::info!("🔄 Starting to process stream chunks (Ollama format)...");
            let mut chunk_count = 0;

            while let Some(chunk) = stream.next().await {
                tracing::debug!("📥 Received raw chunk from stream");

                match chunk {
                    Ok(stream_chunk) => {
                        tracing::debug!("✅ Chunk OK, checking for content...");

                        // Check for content
                        if let Some(content) = stream_chunk.get_content() {
                            if !content.is_empty() {
                                chunk_count += 1;
                                tracing::info!("📦 Received chunk #{}: '{}' ({} chars)", chunk_count, content, content.len());

                                // Build Ollama-format streaming response
                                let response_chunk = serde_json::json!({
                                    "model": &model_name,
                                    "created_at": chrono::Utc::now().to_rfc3339(),
                                    "message": {
                                        "role": "assistant",
                                        "content": content,
                                        "images": null
                                    },
                                    "done": false
                                });
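                                // Illustrative shape of the chunk built above (values are
                                // examples only); NDJSON sends it as a single line, SSE
                                // prefixes it with "data: ":
                                // {"model":"llama3","created_at":"...","message":{"role":"assistant","content":"Hi","images":null},"done":false}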

                                let formatted_data = match format {
                                    StreamFormat::SSE => format!("data: {}\n\n", response_chunk),
                                    StreamFormat::NDJSON => format!("{}\n", response_chunk),
                                    StreamFormat::Json => response_chunk.to_string(),
                                };

                                if tx.send(formatted_data).is_err() {
                                    tracing::warn!("⚠️ Failed to send chunk to receiver (client disconnected?)");
                                    break;
                                }
                                tracing::debug!("✅ Sent chunk #{} to client", chunk_count);
                            }
                        } else {
                            tracing::debug!("⚠️ Chunk has no content (likely metadata or finish chunk)");
                        }
                    }
                    Err(e) => {
                        tracing::error!("❌ Stream error: {:?}", e);
                        break;
                    }
                }
            }

            tracing::info!("✅ Stream processing completed. Total chunks: {}", chunk_count);

            // Send final message
            let final_chunk = serde_json::json!({
                "model": model_name,
                "created_at": chrono::Utc::now().to_rfc3339(),
                "message": {
                    "role": "assistant",
                    "content": ""
                },
                "done": true
            });

            let formatted_final = match format {
                StreamFormat::SSE => format!("data: {}\n\n", final_chunk),
                StreamFormat::NDJSON => format!("{}\n", final_chunk),
                StreamFormat::Json => final_chunk.to_string(),
            };
            let _ = tx.send(formatted_final);
            tracing::info!("🏁 Sent final chunk");
        });

        Ok(UnboundedReceiverStream::new(rx))
    }

    /// Send a streaming chat request for OpenAI API (OpenAI-style response)
    ///
    /// This method returns streaming responses in OpenAI API format, which is used by
    /// OpenAI-compatible clients like Codex CLI.
    ///
    /// Key feature: Automatically corrects finish_reason from "stop" to "tool_calls"
    /// when tool_calls are detected in the stream.
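    ///
    /// # Example
    ///
    /// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest).
    /// The model name is illustrative and `write_sse` is a hypothetical sink for the
    /// pre-framed SSE events:
    ///
    /// ```ignore
    /// use futures_util::StreamExt;
    /// use llm_connector::StreamFormat;
    ///
    /// async fn stream_openai(
    ///     client: &Client,
    ///     messages: Vec<llm_connector::types::Message>,
    /// ) -> anyhow::Result<()> {
    ///     // Each item is a pre-framed SSE event ("data: {...}\n\n"),
    ///     // with the last item ending in "data: [DONE]\n\n".
    ///     let mut stream = client
    ///         .chat_stream_openai("gpt-4o-mini", messages, None, StreamFormat::SSE)
    ///         .await?;
    ///     while let Some(event) = stream.next().await {
    ///         write_sse(event).await; // hypothetical: flush the event to the client
    ///     }
    ///     Ok(())
    /// }
    /// ```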
    pub async fn chat_stream_openai(
        &self,
        model: &str,
        messages: Vec<llm_connector::types::Message>,
        tools: Option<Vec<llm_connector::types::Tool>>,
        format: StreamFormat,
    ) -> Result<UnboundedReceiverStream<String>> {
        use futures_util::StreamExt;

        // Messages are already in llm-connector format
        let request = ChatRequest {
            model: model.to_string(),
            messages,
            stream: Some(true),
            tools,
            ..Default::default()
        };

        tracing::info!("🔄 Requesting streaming from LLM connector...");

        // Use real streaming API
        let mut stream = self.llm_client.chat_stream(&request).await
            .map_err(|e| anyhow!("LLM connector streaming error: {}", e))?;

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        let model_name = model.to_string();

        tokio::spawn(async move {
            tracing::info!("🔄 Starting to process stream chunks (OpenAI format)...");
            let mut chunk_count = 0;
            let mut has_tool_calls = false;  // Track if tool_calls detected

            while let Some(chunk) = stream.next().await {
                tracing::debug!("📥 Received raw chunk from stream");

                match chunk {
                    Ok(stream_chunk) => {
                        tracing::debug!("✅ Chunk OK, checking for content or tool_calls...");

                        // Build delta object
                        let mut delta = serde_json::json!({});
                        let mut has_data = false;

                        // Check for content
                        if let Some(content) = stream_chunk.get_content() {
                            if !content.is_empty() {
                                delta["content"] = serde_json::json!(content);
                                has_data = true;
                                chunk_count += 1;
                                tracing::info!("📦 Received chunk #{}: '{}' ({} chars)", chunk_count, content, content.len());
                            }
                        }

                        // Check for tool_calls (extract from choices[0].delta.tool_calls)
                        if let Some(first_choice) = stream_chunk.choices.get(0) {
                            if let Some(tool_calls) = &first_choice.delta.tool_calls {
                                if let Ok(tool_calls_value) = serde_json::to_value(tool_calls) {
                                    delta["tool_calls"] = tool_calls_value;
                                    has_data = true;
                                    has_tool_calls = true;  // Mark tool_calls detected
                                    chunk_count += 1;
                                    tracing::info!("🔧 Received chunk #{} with tool_calls: {} calls", chunk_count, tool_calls.len());
                                }
                            }
                        }

                        if has_data {
                            // Build OpenAI-standard streaming response format
                            let openai_chunk = serde_json::json!({
                                "id": "chatcmpl-123",
                                "object": "chat.completion.chunk",
                                "created": chrono::Utc::now().timestamp(),
                                "model": &model_name,
                                "choices": [{
                                    "index": 0,
                                    "delta": delta,
                                    "finish_reason": null
                                }]
                            });
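                            // In SSE mode this chunk is framed as, e.g. (values illustrative):
                            // data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,
                            //        "model":"gpt-4o-mini","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}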

                            let formatted_data = match format {
                                StreamFormat::SSE => format!("data: {}\n\n", openai_chunk),
                                StreamFormat::NDJSON => format!("{}\n", openai_chunk),
                                StreamFormat::Json => openai_chunk.to_string(),
                            };

                            // Send all chunks immediately (preserve streaming experience)
                            if tx.send(formatted_data).is_err() {
                                tracing::warn!("⚠️ Failed to send chunk to receiver (client disconnected?)");
                                break;
                            }
                            tracing::debug!("✅ Sent chunk #{} to client", chunk_count);
                        } else {
                            tracing::debug!("⚠️ Chunk has no content or tool_calls (likely metadata or finish chunk)");
                        }
                    }
                    Err(e) => {
                        tracing::error!("❌ Stream error: {:?}", e);
                        break;
                    }
                }
            }

            tracing::info!("✅ Stream processing completed. Total chunks: {}", chunk_count);

            // Send final message at stream end
            // 🎯 Key fix: If tool_calls detected, finish_reason should be "tool_calls" not "stop"
            let finish_reason = if has_tool_calls {
                tracing::info!("🎯 Setting finish_reason to 'tool_calls' (detected tool_calls in stream)");
                "tool_calls"
            } else {
                "stop"
            };

            let final_chunk = serde_json::json!({
                "id": "chatcmpl-123",
                "object": "chat.completion.chunk",
                "created": chrono::Utc::now().timestamp(),
                "model": model_name,
                "choices": [{
                    "index": 0,
                    "delta": {},
                    "finish_reason": finish_reason
                }]
            });
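            // Illustrative final SSE frames when tool_calls were seen (values are examples):
            // data: {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,
            //        "model":"gpt-4o-mini","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}
            // data: [DONE]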

            let formatted_final = match format {
                StreamFormat::SSE => format!("data: {}\n\ndata: [DONE]\n\n", final_chunk),
                StreamFormat::NDJSON => format!("{}\n", final_chunk),
                StreamFormat::Json => final_chunk.to_string(),
            };
            let _ = tx.send(formatted_final);
            tracing::info!("🏁 Sent final chunk and [DONE] marker");
        });

        Ok(UnboundedReceiverStream::new(rx))
    }
}