1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
/// Everything a [`Provider`] receives for a single completion call.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier; the Anthropic provider forwards it as the `model` field of the request.
    pub model: String,
    /// Name of the prompt pack; in this file it is only echoed in the mock output.
    pub prompt_pack: String,
    /// Name of the system prompt; copied into [`CompletionResponse::system_prompt`].
    pub system_prompt_name: String,
    /// Full system prompt text; the Anthropic provider sends it as the `system` field.
    pub system_prompt_body: String,
    /// Raw user prompt. The mocks inspect it to pick a scripted scenario, and the Anthropic
    /// provider uses it as the only user message when `messages` yields nothing to send.
    pub prompt: String,
    /// Conversation history as role-tagged content blocks.
    pub messages: Vec<ProviderMessage>,
    /// Tool specifications made available to the model.
    pub tools: Vec<ToolSpec>,
    /// Number of skills; echoed back in [`CompletionResponse::skill_count`].
    pub skill_count: usize,
    /// MCP server configurations keyed by name; only their count is used in this file.
    pub mcp_servers: BTreeMap<String, McpServerConfig>,
}
26
/// Result of a non-streaming completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Name of the system prompt the request used (filled from `system_prompt_name`).
    pub system_prompt: String,
    /// Text produced for the request.
    pub response: String,
    /// Number of tools that were offered in the request.
    pub tool_count: usize,
    /// Skill count copied from the request.
    pub skill_count: usize,
    /// Number of MCP servers listed in the request.
    pub mcp_server_count: usize,
}
35
/// Author of a [`ProviderMessage`]; serialized in snake case (`"user"` / `"assistant"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderRole {
    /// Message authored by the user side of the conversation.
    User,
    /// Message authored by the assistant side of the conversation.
    Assistant,
}
44
/// One conversation turn: a role plus an ordered list of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProviderMessage {
    /// Who authored this turn.
    pub role: ProviderRole,
    /// Content blocks making up the turn, in order.
    pub content: Vec<ProviderContentBlock>,
}
50
/// A single piece of message content.
///
/// Serialized with an internal `type` tag whose value is the snake-case variant name
/// (`text`, `tool_use`, `tool_result`, `thinking`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation.
    ToolUse {
        /// Identifier that a later `ToolResult` refers back to via `tool_use_id`.
        id: String,
        /// Name of the tool being invoked.
        name: String,
        /// Tool arguments; falls back to the `serde_json::Value` default when absent on input.
        #[serde(default)]
        input: serde_json::Value,
    },
    /// Output of a previously requested tool invocation.
    ToolResult {
        /// Identifier of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// Tool output as text.
        content: String,
        /// Whether the tool reported a failure; defaults to `false` when absent on input.
        #[serde(default)]
        is_error: bool,
    },
    /// Model reasoning text. The Anthropic provider in this file strips these blocks
    /// before sending the history upstream.
    Thinking {
        thinking: String,
    },
}
73
/// Streaming notification that the model wants a tool to be run.
#[derive(Debug, Clone, Serialize)]
pub struct ToolUseEvent {
    /// Identifier of this tool invocation.
    pub id: String,
    /// Name of the tool to run.
    pub name: String,
    /// Tool arguments as a string; the producers in this file fill it with JSON text.
    pub input: String,
}
82
/// Streaming notification carrying the outcome of a tool invocation.
#[derive(Debug, Clone, Serialize)]
pub struct ToolResultEvent {
    /// Identifier of the tool invocation this result belongs to.
    pub tool_use_id: String,
    /// Tool output as text.
    pub content: String,
    /// Whether the tool reported a failure.
    pub is_error: bool,
}
89
/// Token counts reported for one request.
#[derive(Debug, Clone, Serialize)]
pub struct UsageEvent {
    /// Tokens consumed by the request input.
    pub input_tokens: u64,
    /// Tokens generated in the response.
    pub output_tokens: u64,
    /// Tokens read from cache (filled from `cache_read_input_tokens` by the Anthropic provider).
    pub cache_read_tokens: u64,
    /// Tokens written to cache (filled from `cache_creation_input_tokens` by the Anthropic provider).
    pub cache_write_tokens: u64,
}
97
/// Items yielded by an [`EventStream`] while a response is being produced.
#[derive(Debug, Clone, Serialize)]
pub enum ApiEvent {
    /// A fragment of the visible response text.
    MessageDelta { text: String },
    /// A fragment of the model's reasoning text.
    ThinkingDelta { text: String },
    /// The model requested a tool invocation.
    ToolUse { tool_use: ToolUseEvent },
    /// The outcome of a tool invocation.
    ToolResult { tool_result: ToolResultEvent },
    /// Token accounting for the request.
    Usage { usage: UsageEvent },
    /// End of the response; the providers in this file send nothing after it.
    Completed,
}
107
/// Failure modes reported by providers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderError {
    /// The HTTP request or response body transfer failed.
    Network { message: String },
    /// The server answered with a non-success HTTP status; `message` holds the response body.
    Api { status: u16, message: String },
    /// The response could not be decoded.
    Parse { message: String },
    /// The request did not complete within the configured per-request timeout.
    Timeout { elapsed_ms: u64 },
    /// The retry loop stopped without a success; `last_error` is the rendered final failure.
    RetryExhausted { attempts: u32, last_error: String },
    /// Any failure that fits none of the other variants.
    Other { message: String },
}
119
120impl std::fmt::Display for ProviderError {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 ProviderError::Network { message } => write!(f, "Network error: {message}"),
124 ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125 ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126 ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127 ProviderError::RetryExhausted {
128 attempts,
129 last_error,
130 } => {
131 write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132 }
133 ProviderError::Other { message } => write!(f, "{message}"),
134 }
135 }
136}
137
// Marker impl: `Debug` and `Display` are provided above, so the default methods suffice.
impl std::error::Error for ProviderError {}
139
/// Running totals of token usage across requests.
#[derive(Debug, Clone, Default, Serialize)]
pub struct UsageAccount {
    /// Sum of `input_tokens` over all recorded events.
    pub total_input_tokens: u64,
    /// Sum of `output_tokens` over all recorded events.
    pub total_output_tokens: u64,
    /// Sum of `cache_read_tokens` over all recorded events.
    pub total_cache_read_tokens: u64,
    /// Sum of `cache_write_tokens` over all recorded events.
    pub total_cache_write_tokens: u64,
    /// Number of usage events recorded so far.
    pub request_count: u64,
}
150
151impl UsageAccount {
152 pub fn record(&mut self, usage: &UsageEvent) {
153 self.total_input_tokens += usage.input_tokens;
154 self.total_output_tokens += usage.output_tokens;
155 self.total_cache_read_tokens += usage.cache_read_tokens;
156 self.total_cache_write_tokens += usage.cache_write_tokens;
157 self.request_count += 1;
158 }
159}
160
/// Retry policy for providers that talk to a remote service.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, the first one included.
    pub max_attempts: u32,
    /// Delay used as the starting point when waiting between attempts.
    pub base_delay: Duration,
    /// Upper bound applied to the computed backoff delay.
    pub max_delay: Duration,
}
169
170impl Default for RetryConfig {
171 fn default() -> Self {
172 Self {
173 max_attempts: 3,
174 base_delay: Duration::from_millis(200),
175 max_delay: Duration::from_secs(5),
176 }
177 }
178}
179
/// Timeout policy for providers that talk to a remote service.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Upper bound the Anthropic provider places on awaiting a single HTTP `send`.
    pub per_request: Duration,
}
184
185impl Default for TimeoutConfig {
186 fn default() -> Self {
187 Self {
188 per_request: Duration::from_secs(60),
189 }
190 }
191}
192
/// Boxed, sendable stream of provider events; each item is either an [`ApiEvent`] or the
/// [`ProviderError`] that interrupted the response.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
/// A backend able to answer completion requests, either in one piece or as a stream.
pub trait Provider: Send + Sync {
    /// Produces the whole response at once.
    ///
    /// The returned future may borrow `self` but not `request`.
    fn complete(
        &self,
        request: &CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;

    /// Produces the response incrementally as an [`EventStream`].
    ///
    /// NOTE(review): every implementation in this file calls `tokio::spawn`, so presumably
    /// this must be invoked from within a Tokio runtime — confirm.
    fn stream(&self, request: &CompletionRequest) -> EventStream;
}
205
/// Owned, type-erased [`Provider`], as returned by [`create_provider`].
pub type BoxedProvider = Box<dyn Provider>;
209
/// Offline provider returning canned responses; this is what [`create_provider`] falls
/// back to when no real provider is selected.
#[derive(Debug, Default, Clone)]
pub struct MockProvider;
214
215impl Provider for MockProvider {
216 fn complete(
217 &self,
218 request: &CompletionRequest,
219 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220 let response = mock_complete_response(request);
221 Box::pin(async move { Ok(response) })
222 }
223
224 fn stream(&self, request: &CompletionRequest) -> EventStream {
225 let req = request.clone();
226 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228 tokio::spawn(async move {
229 let events = mock_stream_events(&req);
230 for (idx, event) in events.into_iter().enumerate() {
231 if idx > 0 {
232 tokio::time::sleep(Duration::from_millis(18)).await;
233 }
234 if tx.send(Ok(event)).await.is_err() {
235 return;
236 }
237 }
238 });
239
240 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241 rx.recv().await.map(|item| (item, rx))
242 }))
243 }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247 if wants_read_cargo_toml(&request.prompt) {
248 if let Some(tool_content) = first_tool_result_content(&request.messages) {
249 let response = mock_summarize_cargo_toml(&tool_content);
250 return CompletionResponse {
251 system_prompt: request.system_prompt_name.clone(),
252 response,
253 tool_count: request.tools.len(),
254 skill_count: request.skill_count,
255 mcp_server_count: request.mcp_servers.len(),
256 };
257 }
258
259 return CompletionResponse {
260 system_prompt: request.system_prompt_name.clone(),
261 response: "I'll read Cargo.toml first.".to_string(),
262 tool_count: request.tools.len(),
263 skill_count: request.skill_count,
264 mcp_server_count: request.mcp_servers.len(),
265 };
266 }
267
268 let response = format!(
269 "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n\nRequest queued for the execution loop.\n\nNext priorities:\n1. Parse instructions into an explicit task graph.\n2. Resolve tool approvals before execution.\n3. Stream structured updates into the terminal UI.",
270 request.model,
271 request.prompt_pack,
272 request.system_prompt_name,
273 request
274 .tools
275 .iter()
276 .map(|tool| tool.name)
277 .collect::<Vec<_>>()
278 .join(", "),
279 request.skill_count,
280 request.mcp_servers.len(),
281 );
282
283 CompletionResponse {
284 system_prompt: request.system_prompt_name.clone(),
285 response,
286 tool_count: request.tools.len(),
287 skill_count: request.skill_count,
288 mcp_server_count: request.mcp_servers.len(),
289 }
290}
291
292fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
293 if wants_read_cargo_toml(&request.prompt) {
294 if let Some(tool_content) = first_tool_result_content(&request.messages) {
295 return vec![
296 ApiEvent::ThinkingDelta {
297 text: "I have the Cargo.toml contents; summarizing.".to_string(),
298 },
299 ApiEvent::MessageDelta {
300 text: mock_summarize_cargo_toml(&tool_content),
301 },
302 ApiEvent::Usage {
303 usage: UsageEvent {
304 input_tokens: 220,
305 output_tokens: 90,
306 cache_read_tokens: 0,
307 cache_write_tokens: 0,
308 },
309 },
310 ApiEvent::Completed,
311 ];
312 }
313
314 return vec![
315 ApiEvent::ThinkingDelta {
316 text: "I should read Cargo.toml to answer this.".to_string(),
317 },
318 ApiEvent::ToolUse {
319 tool_use: ToolUseEvent {
320 id: "tool_1".to_string(),
321 name: "read_file".to_string(),
322 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
323 },
324 },
325 ApiEvent::Completed,
326 ];
327 }
328
329 let tool_names: Vec<&str> = request.tools.iter().map(|t| t.name.as_ref()).collect();
330
331 vec![
332 ApiEvent::ThinkingDelta {
333 text: "Let me start by analyzing the request.".to_string(),
334 },
335 ApiEvent::MessageDelta {
336 text: format!(
337 "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n",
338 request.model,
339 request.prompt_pack,
340 request.system_prompt_name,
341 tool_names.join(", "),
342 request.skill_count,
343 request.mcp_servers.len(),
344 ),
345 },
346 ApiEvent::MessageDelta {
347 text: "\nRequest queued for the execution loop.\n\nNext priorities:\n".to_string(),
348 },
349 ApiEvent::MessageDelta {
350 text: "1. Parse instructions into an explicit task graph.\n".to_string(),
351 },
352 ApiEvent::MessageDelta {
353 text: "2. Resolve tool approvals before execution.\n".to_string(),
354 },
355 ApiEvent::MessageDelta {
356 text: "3. Stream structured updates into the terminal UI.".to_string(),
357 },
358 ApiEvent::Usage {
359 usage: UsageEvent {
360 input_tokens: 120,
361 output_tokens: 85,
362 cache_read_tokens: 0,
363 cache_write_tokens: 0,
364 },
365 },
366 ApiEvent::Completed,
367 ]
368}
369
/// Returns `true` when `prompt` mentions `cargo.toml` together with "read" or "summarize".
///
/// Matching is ASCII case-insensitive and purely substring based.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    if !lowered.contains("cargo.toml") {
        return false;
    }
    ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb))
}
374
375fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
376 for m in messages {
377 for b in &m.content {
378 if let ProviderContentBlock::ToolResult { content, .. } = b {
379 return Some(content.clone());
380 }
381 }
382 }
383 None
384}
385
/// Produces the canned summary the mock returns for the given `Cargo.toml` text.
///
/// Only two substrings are inspected: `[workspace]` decides between the workspace summary
/// and the "not a workspace" message, and `members` adds one extra sentence.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let members_note = if contents.contains("members") {
        "It declares workspace members; this repo is a multi-crate workspace.\n"
    } else {
        ""
    };

    [
        "Cargo.toml defines a Rust workspace.\n",
        members_note,
        "Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.",
    ]
    .concat()
}
398
/// Offline provider that always plays the scripted tool round-trip: it requests a
/// `read_file` call on `Cargo.toml` until a tool result appears in the history, then
/// answers with a fixed summary.
#[derive(Debug, Default, Clone)]
pub struct MockToolProvider;
411
412impl Provider for MockToolProvider {
413 fn complete(
414 &self,
415 request: &CompletionRequest,
416 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
417 let response = mock_tool_complete_response(request);
418 Box::pin(async move { Ok(response) })
419 }
420
421 fn stream(&self, request: &CompletionRequest) -> EventStream {
422 let req = request.clone();
423 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
424
425 tokio::spawn(async move {
426 let events = mock_tool_stream_events(&req);
427 for (idx, event) in events.into_iter().enumerate() {
428 if idx > 0 {
429 tokio::time::sleep(Duration::from_millis(5)).await;
430 }
431 if tx.send(Ok(event)).await.is_err() {
432 return;
433 }
434 }
435 });
436
437 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
438 rx.recv().await.map(|item| (item, rx))
439 }))
440 }
441}
442
443fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
444 if has_tool_result_message(request) {
445 let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
446 return CompletionResponse {
447 system_prompt: request.system_prompt_name.clone(),
448 response,
449 tool_count: request.tools.len(),
450 skill_count: request.skill_count,
451 mcp_server_count: request.mcp_servers.len(),
452 };
453 }
454
455 CompletionResponse {
456 system_prompt: request.system_prompt_name.clone(),
457 response: "I'll read the Cargo.toml file.".to_string(),
458 tool_count: request.tools.len(),
459 skill_count: request.skill_count,
460 mcp_server_count: request.mcp_servers.len(),
461 }
462}
463
464fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
465 if has_tool_result_message(request) {
466 return vec![
467 ApiEvent::ThinkingDelta {
468 text: "I have the file contents now.".to_string(),
469 },
470 ApiEvent::MessageDelta {
471 text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
472 },
473 ApiEvent::Usage {
474 usage: UsageEvent {
475 input_tokens: 200,
476 output_tokens: 60,
477 cache_read_tokens: 0,
478 cache_write_tokens: 0,
479 },
480 },
481 ApiEvent::Completed,
482 ];
483 }
484
485 vec![
486 ApiEvent::ThinkingDelta {
487 text: "I should read the Cargo.toml file.".to_string(),
488 },
489 ApiEvent::ToolUse {
490 tool_use: ToolUseEvent {
491 id: "tool_1".to_string(),
492 name: "read_file".to_string(),
493 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
494 },
495 },
496 ApiEvent::Usage {
497 usage: UsageEvent {
498 input_tokens: 100,
499 output_tokens: 30,
500 cache_read_tokens: 0,
501 cache_write_tokens: 0,
502 },
503 },
504 ApiEvent::Completed,
505 ]
506}
507
508fn has_tool_result_message(request: &CompletionRequest) -> bool {
511 request.messages.iter().any(|m| {
512 m.content
513 .iter()
514 .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
515 })
516}
517
/// Provider implementation backed by Anthropic's Messages HTTP API.
///
/// Only compiled when the `anthropic` cargo feature is enabled.
#[cfg(feature = "anthropic")]
pub mod anthropic_provider {
    use super::*;
    use futures_util::StreamExt;
    use reqwest::Client;

    /// Endpoint used when no override is configured.
    const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
    /// `anthropic-version` header value used when no override is configured.
    const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01";
    /// Value sent as `max_tokens` on every request.
    const MAX_TOKENS: u32 = 4096;

    /// One item travelling from the background task to the stream consumer.
    type StreamItem = Result<ApiEvent, ProviderError>;

    /// Marker error: the consumer dropped the receiving half of the event channel.
    struct ReceiverClosed;

    /// Client for the Anthropic Messages API.
    #[derive(Debug, Clone)]
    pub struct AnthropicProvider {
        client: Client,
        api_key: String,
        endpoint: String,
        anthropic_version: String,
        retry: RetryConfig,
        timeout: TimeoutConfig,
    }

    impl AnthropicProvider {
        /// Builds a provider from environment variables.
        ///
        /// * `ANTHROPIC_API_KEY` — required; returns `None` when unset or blank.
        /// * `CLAWEDCODE_ANTHROPIC_ENDPOINT` — full messages endpoint URL, used verbatim.
        /// * `ANTHROPIC_BASE_URL` — consulted only when the endpoint variable is absent; a
        ///   bare base URL gets `/v1/messages` appended, a full endpoint is kept as-is.
        /// * `CLAWEDCODE_ANTHROPIC_VERSION` — overrides the `anthropic-version` header.
        pub fn from_env() -> Option<Self> {
            let api_key = non_empty_env("ANTHROPIC_API_KEY")?;
            let endpoint = non_empty_env("CLAWEDCODE_ANTHROPIC_ENDPOINT")
                .or_else(|| {
                    non_empty_env("ANTHROPIC_BASE_URL").map(|base| endpoint_from_base_url(&base))
                })
                .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());
            let anthropic_version = non_empty_env("CLAWEDCODE_ANTHROPIC_VERSION")
                .unwrap_or_else(|| DEFAULT_ANTHROPIC_VERSION.to_string());
            Some(Self {
                client: Client::new(),
                api_key,
                endpoint,
                anthropic_version,
                retry: RetryConfig::default(),
                timeout: TimeoutConfig::default(),
            })
        }

        /// Builds a provider for an explicit key and full endpoint URL, with the default
        /// API version, retry policy and timeout.
        pub fn new(api_key: String, endpoint: String) -> Self {
            Self {
                client: Client::new(),
                api_key,
                endpoint,
                anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(),
                retry: RetryConfig::default(),
                timeout: TimeoutConfig::default(),
            }
        }

        /// Sends `body` once and waits for the response head, bounded by the per-request
        /// timeout. Transport failures map to `Network`, an elapsed timer to `Timeout`.
        /// The HTTP status is not inspected here.
        async fn post(&self, body: &serde_json::Value) -> Result<reqwest::Response, ProviderError> {
            let send = self
                .client
                .post(&self.endpoint)
                .header("x-api-key", &self.api_key)
                .header("anthropic-version", &self.anthropic_version)
                .header("content-type", "application/json")
                .json(body)
                .send();

            match tokio::time::timeout(self.timeout.per_request, send).await {
                Ok(Ok(resp)) => Ok(resp),
                Ok(Err(e)) => Err(ProviderError::Network {
                    message: e.to_string(),
                }),
                Err(_) => Err(ProviderError::Timeout {
                    elapsed_ms: duration_ms(self.timeout.per_request),
                }),
            }
        }

        /// Runs one streaming request and forwards decoded events into `tx`.
        ///
        /// Returns `Err(ReceiverClosed)` as soon as the consumer is gone so the HTTP body
        /// is not read for nobody. Provider-side failures are delivered through the
        /// channel as `Err` items and end the stream.
        async fn run_stream(
            &self,
            req: &CompletionRequest,
            tx: &mpsc::Sender<StreamItem>,
        ) -> Result<(), ReceiverClosed> {
            let body = build_request_body(req, true);

            let resp = match self.post(&body).await {
                Ok(resp) => resp,
                Err(err) => return emit(tx, Err(err)).await,
            };

            if !resp.status().is_success() {
                let status = resp.status().as_u16();
                let message = resp.text().await.unwrap_or_default();
                return emit(tx, Err(ProviderError::Api { status, message })).await;
            }

            let mut decoder = SseDecoder::default();
            let mut bytes = resp.bytes_stream();

            while let Some(chunk) = bytes.next().await {
                let chunk = match chunk {
                    Ok(c) => c,
                    Err(e) => {
                        let err = ProviderError::Network {
                            message: e.to_string(),
                        };
                        return emit(tx, Err(err)).await;
                    }
                };
                if forward(tx, decoder.push_chunk(&chunk)).await? {
                    return Ok(());
                }
            }

            // The body ended without `message_stop`: decode a trailing frame that lacked
            // its blank-line terminator, then close the stream explicitly.
            if forward(tx, decoder.finish()).await? {
                return Ok(());
            }
            emit(tx, Ok(ApiEvent::Completed)).await
        }
    }

    /// Reads an environment variable, treating unset, non-unicode and blank values alike.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .filter(|value| !value.trim().is_empty())
    }

    /// Turns an `ANTHROPIC_BASE_URL`-style value into the messages endpoint.
    ///
    /// A value that already ends in `/v1/messages` is kept (minus trailing slashes), one
    /// ending in `/v1` gets `/messages`, anything else gets `/v1/messages`.
    fn endpoint_from_base_url(base: &str) -> String {
        let trimmed = base.trim().trim_end_matches('/');
        if trimmed.ends_with("/v1/messages") {
            trimmed.to_string()
        } else if trimmed.ends_with("/v1") {
            format!("{trimmed}/messages")
        } else {
            format!("{trimmed}/v1/messages")
        }
    }

    /// Converts a duration to whole milliseconds, clamping instead of truncating.
    fn duration_ms(duration: Duration) -> u64 {
        u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
    }

    /// Delay before the retry that follows failed attempt number `attempt` (1-based):
    /// `base_delay * 2^(attempt - 1)`, saturating, capped at `max_delay`.
    fn backoff_delay(retry: &RetryConfig, attempt: u32) -> Duration {
        let factor = 1u32
            .checked_shl(attempt.saturating_sub(1))
            .unwrap_or(u32::MAX);
        retry.base_delay.saturating_mul(factor).min(retry.max_delay)
    }

    /// Whether a failed attempt is worth repeating: transport errors, timeouts, rate
    /// limiting (429) and server-side errors (5xx).
    fn is_retryable(err: &ProviderError) -> bool {
        match err {
            ProviderError::Network { .. } | ProviderError::Timeout { .. } => true,
            ProviderError::Api { status, .. } => *status == 429 || (500..=599).contains(status),
            ProviderError::Parse { .. }
            | ProviderError::RetryExhausted { .. }
            | ProviderError::Other { .. } => false,
        }
    }

    /// Converts tool specs into the API's `tools` array, or `Null` when there are none so
    /// the caller can omit the field.
    fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
        if tools.is_empty() {
            return serde_json::Value::Null;
        }
        serde_json::Value::Array(
            tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "input_schema": t.input_schema,
                    })
                })
                .collect(),
        )
    }

    /// Copies the history without `Thinking` blocks and drops messages left empty.
    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
        messages
            .iter()
            .map(|m| ProviderMessage {
                role: m.role.clone(),
                content: m
                    .content
                    .iter()
                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
                    .cloned()
                    .collect(),
            })
            .filter(|m| !m.content.is_empty())
            .collect()
    }

    /// Assembles the JSON request body shared by `complete` and `stream`.
    ///
    /// When the sanitized history is empty (or cannot be serialized) the raw prompt is
    /// sent as a single user message.
    fn build_request_body(req: &CompletionRequest, stream: bool) -> serde_json::Value {
        let prompt_only = || serde_json::json!([{"role": "user", "content": req.prompt}]);

        let sanitized = sanitize_messages_for_anthropic(&req.messages);
        let messages = if sanitized.is_empty() {
            prompt_only()
        } else {
            serde_json::to_value(sanitized).unwrap_or_else(|_| prompt_only())
        };

        let mut body = serde_json::json!({
            "model": req.model,
            "max_tokens": MAX_TOKENS,
            "system": req.system_prompt_body,
            "messages": messages,
        });
        if stream {
            body["stream"] = serde_json::Value::Bool(true);
        }
        let tools = build_anthropic_tools(&req.tools);
        if !tools.is_null() {
            body["tools"] = tools;
        }
        body
    }

    /// Concatenates the `text` of every block in a non-streaming response's `content`.
    fn collect_text(json: &serde_json::Value) -> String {
        json["content"]
            .as_array()
            .map(|blocks| {
                blocks
                    .iter()
                    .filter_map(|block| block["text"].as_str())
                    .collect::<String>()
            })
            .unwrap_or_default()
    }

    /// Sends one item to the consumer, reporting a dropped receiver.
    async fn emit(tx: &mpsc::Sender<StreamItem>, item: StreamItem) -> Result<(), ReceiverClosed> {
        tx.send(item).await.map_err(|_| ReceiverClosed)
    }

    /// `Completed` and any error end the stream.
    fn is_terminal(item: &StreamItem) -> bool {
        matches!(item, Ok(ApiEvent::Completed) | Err(_))
    }

    /// Sends `items` in order; returns `Ok(true)` once a terminal item has gone out
    /// (anything after it is discarded).
    async fn forward(
        tx: &mpsc::Sender<StreamItem>,
        items: Vec<StreamItem>,
    ) -> Result<bool, ReceiverClosed> {
        for item in items {
            let terminal = is_terminal(&item);
            emit(tx, item).await?;
            if terminal {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Position of the first occurrence of `needle` in `haystack`.
    fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }

    /// Locates the earliest blank-line frame terminator (`\n\n` or `\r\n\r\n`), returning
    /// the frame length and the terminator length.
    fn find_frame_end(buf: &[u8]) -> Option<(usize, usize)> {
        let lf = find_subslice(buf, b"\n\n").map(|at| (at, 2));
        let crlf = find_subslice(buf, b"\r\n\r\n").map(|at| (at, 4));
        match (lf, crlf) {
            (Some(a), Some(b)) => Some(if b.0 < a.0 { b } else { a }),
            (a, b) => a.or(b),
        }
    }

    /// A tool-use content block whose input JSON is still being streamed.
    struct PendingToolUse {
        id: String,
        name: String,
        /// Non-empty input already present on `content_block_start`, serialized.
        seeded_input: String,
        /// Concatenation of the `input_json_delta` fragments received so far.
        partial_json: String,
    }

    impl PendingToolUse {
        /// Finalizes the block. Streamed fragments win over the seed, because the start
        /// event usually carries a placeholder `{}` that must not be prepended to them;
        /// with neither available the input is `{}`.
        fn into_event(self) -> ToolUseEvent {
            let input = if !self.partial_json.trim().is_empty() {
                self.partial_json
            } else if !self.seeded_input.trim().is_empty() {
                self.seeded_input
            } else {
                "{}".to_string()
            };
            ToolUseEvent {
                id: self.id,
                name: self.name,
                input,
            }
        }
    }

    /// Incremental decoder turning raw server-sent-event bytes into stream items.
    #[derive(Default)]
    struct SseDecoder {
        /// Bytes received but not yet terminated by a blank line. Kept as bytes so a
        /// multi-byte UTF-8 character split across network chunks is decoded intact.
        pending: Vec<u8>,
        /// Tool-use block currently being assembled, if any.
        tool_use: Option<PendingToolUse>,
        /// Input-side counts from `message_start`, merged into the later usage report.
        input_tokens: u64,
        cache_read_tokens: u64,
        cache_write_tokens: u64,
    }

    impl SseDecoder {
        /// Appends `chunk` and decodes every frame completed by it. Decoding stops after
        /// a terminal item.
        fn push_chunk(&mut self, chunk: &[u8]) -> Vec<StreamItem> {
            self.pending.extend_from_slice(chunk);
            let mut out = Vec::new();
            while let Some((end, delimiter_len)) = find_frame_end(&self.pending) {
                let frame: Vec<u8> = self.pending.drain(..end + delimiter_len).collect();
                self.decode_frame(&String::from_utf8_lossy(&frame[..end]), &mut out);
                if out.last().map_or(false, is_terminal) {
                    break;
                }
            }
            out
        }

        /// Decodes whatever is left in the buffer as one final, unterminated frame.
        fn finish(&mut self) -> Vec<StreamItem> {
            let rest = std::mem::take(&mut self.pending);
            let mut out = Vec::new();
            if !rest.is_empty() {
                self.decode_frame(&String::from_utf8_lossy(&rest), &mut out);
            }
            out
        }

        /// Extracts the `data:` payload of one frame and decodes it. Frames without
        /// data, `[DONE]` markers and payloads that are not valid JSON are skipped.
        fn decode_frame(&mut self, frame: &str, out: &mut Vec<StreamItem>) {
            let data = frame
                .lines()
                .filter_map(|line| line.trim().strip_prefix("data:"))
                .map(str::trim)
                .filter(|payload| !payload.is_empty())
                .collect::<Vec<_>>()
                .join("\n");
            if data.is_empty() || data == "[DONE]" {
                return;
            }
            let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
                return;
            };
            self.decode_event(&event, out);
        }

        /// Translates one decoded API event into zero or more stream items.
        fn decode_event(&mut self, event: &serde_json::Value, out: &mut Vec<StreamItem>) {
            match event["type"].as_str().unwrap_or("") {
                "message_start" => {
                    // Input-side counts arrive here; remember them for the usage report.
                    let usage = &event["message"]["usage"];
                    self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
                    self.cache_read_tokens =
                        usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
                    self.cache_write_tokens =
                        usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
                }
                "content_block_start" => {
                    let block = &event["content_block"];
                    match block["type"].as_str().unwrap_or("") {
                        "text" => {
                            if let Some(text) = block["text"].as_str().filter(|t| !t.is_empty()) {
                                out.push(Ok(ApiEvent::MessageDelta {
                                    text: text.to_string(),
                                }));
                            }
                        }
                        "tool_use" => {
                            let seeded_input = match &block["input"] {
                                serde_json::Value::Null => String::new(),
                                serde_json::Value::Object(map) if map.is_empty() => String::new(),
                                other => serde_json::to_string(other).unwrap_or_default(),
                            };
                            self.tool_use = Some(PendingToolUse {
                                id: block["id"].as_str().unwrap_or("").to_string(),
                                name: block["name"].as_str().unwrap_or("").to_string(),
                                seeded_input,
                                partial_json: String::new(),
                            });
                        }
                        "thinking" => {
                            if let Some(text) =
                                block["thinking"].as_str().filter(|t| !t.is_empty())
                            {
                                out.push(Ok(ApiEvent::ThinkingDelta {
                                    text: text.to_string(),
                                }));
                            }
                        }
                        _ => {}
                    }
                }
                "content_block_delta" => {
                    let delta = &event["delta"];
                    match delta["type"].as_str().unwrap_or("") {
                        "text_delta" => {
                            if let Some(text) = delta["text"].as_str() {
                                out.push(Ok(ApiEvent::MessageDelta {
                                    text: text.to_string(),
                                }));
                            }
                        }
                        "thinking_delta" => {
                            if let Some(text) = delta["thinking"].as_str() {
                                out.push(Ok(ApiEvent::ThinkingDelta {
                                    text: text.to_string(),
                                }));
                            }
                        }
                        "input_json_delta" => {
                            if let (Some(partial), Some(tool)) =
                                (delta["partial_json"].as_str(), self.tool_use.as_mut())
                            {
                                tool.partial_json.push_str(partial);
                            }
                        }
                        _ => {}
                    }
                }
                "content_block_stop" => {
                    if let Some(tool) = self.tool_use.take() {
                        out.push(Ok(ApiEvent::ToolUse {
                            tool_use: tool.into_event(),
                        }));
                    }
                }
                "message_delta" => {
                    if let Some(usage) = event.get("usage") {
                        // Fields missing here fall back to what `message_start` reported.
                        out.push(Ok(ApiEvent::Usage {
                            usage: UsageEvent {
                                input_tokens: usage["input_tokens"]
                                    .as_u64()
                                    .unwrap_or(self.input_tokens),
                                output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
                                cache_read_tokens: usage["cache_read_input_tokens"]
                                    .as_u64()
                                    .unwrap_or(self.cache_read_tokens),
                                cache_write_tokens: usage["cache_creation_input_tokens"]
                                    .as_u64()
                                    .unwrap_or(self.cache_write_tokens),
                            },
                        }));
                    }
                }
                "message_stop" => out.push(Ok(ApiEvent::Completed)),
                "error" => {
                    // Mid-stream failures (e.g. overload) arrive as events after the HTTP
                    // status was already 200; surface them instead of ending "successfully".
                    let kind = event["error"]["type"].as_str().unwrap_or("unknown_error");
                    let detail = event["error"]["message"]
                        .as_str()
                        .unwrap_or("no message provided");
                    out.push(Err(ProviderError::Other {
                        message: format!("Anthropic stream error ({kind}): {detail}"),
                    }));
                }
                _ => {}
            }
        }
    }

    impl Provider for AnthropicProvider {
        /// Performs a non-streaming completion with retries.
        ///
        /// Transport errors, timeouts, HTTP 429 and HTTP 5xx are retried up to
        /// `retry.max_attempts` total attempts with capped exponential backoff. Any other
        /// HTTP error stops the loop at once. Failures are reported as `RetryExhausted`
        /// carrying the number of attempts actually made; an undecodable success body is
        /// reported as `Parse`.
        fn complete(
            &self,
            request: &CompletionRequest,
        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
        {
            let req = request.clone();
            let provider = self.clone();
            Box::pin(async move {
                let body = build_request_body(&req, false);
                let max_attempts = provider.retry.max_attempts;
                let mut attempts_made = 0u32;
                let mut last_err: Option<ProviderError> = None;

                for attempt in 1..=max_attempts {
                    attempts_made = attempt;

                    let failure = match provider.post(&body).await {
                        Ok(resp) if resp.status().is_success() => {
                            let json: serde_json::Value =
                                resp.json().await.map_err(|e| ProviderError::Parse {
                                    message: e.to_string(),
                                })?;
                            return Ok(CompletionResponse {
                                system_prompt: req.system_prompt_name.clone(),
                                response: collect_text(&json),
                                tool_count: req.tools.len(),
                                skill_count: req.skill_count,
                                mcp_server_count: req.mcp_servers.len(),
                            });
                        }
                        Ok(resp) => {
                            let status = resp.status().as_u16();
                            let message = resp.text().await.unwrap_or_default();
                            ProviderError::Api { status, message }
                        }
                        Err(err) => err,
                    };

                    let retry_again = is_retryable(&failure) && attempt < max_attempts;
                    last_err = Some(failure);
                    if !retry_again {
                        break;
                    }
                    tokio::time::sleep(backoff_delay(&provider.retry, attempt)).await;
                }

                Err(match last_err {
                    Some(e) => ProviderError::RetryExhausted {
                        attempts: attempts_made,
                        last_error: e.to_string(),
                    },
                    // Only reachable when `max_attempts` is zero.
                    None => ProviderError::Other {
                        message: "request failed".to_string(),
                    },
                })
            })
        }

        /// Starts a streaming completion on a background task and returns its events.
        ///
        /// The request is sent once (no retries). The stream ends after `Completed` or
        /// after the first error item.
        fn stream(&self, request: &CompletionRequest) -> EventStream {
            let req = request.clone();
            let provider = self.clone();
            let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);

            tokio::spawn(async move {
                // `Err` only means the consumer went away; there is nobody left to tell.
                let _ = provider.run_stream(&req, &tx).await;
            });

            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
                rx.recv().await.map(|item| (item, rx))
            }))
        }
    }
}
976
977pub fn create_provider() -> BoxedProvider {
980 let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
981
982 match provider_name.as_str() {
983 #[cfg(feature = "anthropic")]
984 "anthropic" => {
985 if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
986 tracing::info!("Using Anthropic provider");
987 return Box::new(p);
988 }
989 tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
990 }
991 #[cfg(not(feature = "anthropic"))]
992 "anthropic" => {
993 tracing::warn!(
994 "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
995 );
996 }
997 _ => {}
998 }
999
1000 tracing::info!("Using mock provider");
1001 Box::new(MockProvider)
1002}
1003
1004pub async fn collect_stream_to_response(
1007 stream: EventStream,
1008 request: &CompletionRequest,
1009) -> Result<CompletionResponse, ProviderError> {
1010 use futures_util::StreamExt;
1011 let mut text = String::new();
1012 let mut thinking = String::new();
1013
1014 let mut s = stream;
1015 while let Some(event) = s.next().await {
1016 match event? {
1017 ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1018 ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1019 ApiEvent::Usage { usage: _ } => {}
1020 ApiEvent::Completed => break,
1021 ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1022 }
1023 }
1024
1025 Ok(CompletionResponse {
1026 system_prompt: request.system_prompt_name.clone(),
1027 response: text,
1028 tool_count: request.tools.len(),
1029 skill_count: request.skill_count,
1030 mcp_server_count: request.mcp_servers.len(),
1031 })
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request that takes the default (non-tool) mock scenario.
    fn hello_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Streams `request` through `provider` and keeps only the successful events.
    async fn ok_events(provider: &dyn Provider, request: &CompletionRequest) -> Vec<ApiEvent> {
        provider
            .stream(request)
            .filter_map(|item| async move { item.ok() })
            .collect()
            .await
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = ok_events(&MockProvider, &hello_request()).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = ok_events(&MockProvider, &hello_request()).await;

        assert!(matches!(events[0], ApiEvent::ThinkingDelta { .. }));
        assert!(matches!(events[1], ApiEvent::MessageDelta { .. }));

        let index_of = |wanted: fn(&ApiEvent) -> bool| events.iter().position(wanted);
        let usage_idx = index_of(|e| matches!(e, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_idx = index_of(|e| matches!(e, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let provider = MockProvider;
        let request = hello_request();

        let direct = provider.complete(&request).await.unwrap();
        let streamed: String = ok_events(&provider, &request)
            .await
            .iter()
            .filter_map(|event| match event {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let reports = [
            UsageEvent {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_tokens: 10,
                cache_write_tokens: 20,
            },
            UsageEvent {
                input_tokens: 200,
                output_tokens: 75,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        ];

        let mut account = UsageAccount::default();
        for report in &reports {
            account.record(report);
        }

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let timeout = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(timeout.to_string().contains("5000"));

        let exhausted = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(exhausted.to_string().contains("3"));
    }
}