1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
/// Everything a [`Provider`] needs to produce one completion.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier; the Anthropic provider sends it as the request's `model` field.
    pub model: String,
    /// Name of the active prompt pack. In this file it is only echoed by the mock providers.
    pub prompt_pack: String,
    /// Display name of the system prompt; copied into [`CompletionResponse::system_prompt`].
    pub system_prompt_name: String,
    /// Full system prompt text; the Anthropic provider sends it as the `system` field.
    pub system_prompt_body: String,
    /// The user's prompt. The Anthropic provider uses it as the sole user message when
    /// `messages` is empty (or contains only thinking blocks).
    pub prompt: String,
    /// Conversation history, oldest first.
    pub messages: Vec<ProviderMessage>,
    /// Tools offered to the model for this request.
    pub tools: Vec<ToolSpec>,
    /// Number of discovered skills; reported back unchanged in the response.
    pub skill_count: usize,
    /// Configured MCP servers keyed by name; only their count is used in this file.
    pub mcp_servers: BTreeMap<String, McpServerConfig>,
}
26
/// Result of a non-streaming completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Name (not body) of the system prompt the request used.
    pub system_prompt: String,
    /// The assistant's text answer.
    pub response: String,
    /// Number of tools that were offered in the request.
    pub tool_count: usize,
    /// Skill count copied from the request.
    pub skill_count: usize,
    /// Number of MCP servers listed in the request.
    pub mcp_server_count: usize,
}
35
/// Author of a [`ProviderMessage`]; serialized in snake case (`"user"` / `"assistant"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderRole {
    /// Message written by the user (tool results are also carried in content blocks).
    User,
    /// Message produced by the model.
    Assistant,
}
44
/// One turn of conversation history: a role plus an ordered list of content blocks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProviderMessage {
    /// Who authored the message.
    pub role: ProviderRole,
    /// The message body, split into typed blocks.
    pub content: Vec<ProviderContentBlock>,
}
50
/// A typed piece of message content.
///
/// Serialized as an internally tagged object: the variant name, in snake case, is stored
/// under the `type` key next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation requested by the model.
    ToolUse {
        /// Identifier that the matching `ToolResult` refers back to.
        id: String,
        /// Name of the tool to run.
        name: String,
        /// Tool arguments as JSON; falls back to the JSON default value when absent.
        #[serde(default)]
        input: serde_json::Value,
    },
    /// The outcome of a previously requested tool invocation.
    ToolResult {
        /// `id` of the `ToolUse` block this result answers.
        tool_use_id: String,
        /// Tool output as text.
        content: String,
        /// Whether the tool failed; `false` when the field is missing on input.
        #[serde(default)]
        is_error: bool,
    },
    /// Model reasoning. The Anthropic provider in this file strips these blocks before
    /// sending history back to the API.
    Thinking {
        thinking: String,
    },
}
73
/// Payload of [`ApiEvent::ToolUse`]: a complete tool call emitted by a stream.
#[derive(Debug, Clone, Serialize)]
pub struct ToolUseEvent {
    /// Identifier of this tool call.
    pub id: String,
    /// Name of the tool the model wants to run.
    pub name: String,
    /// Tool arguments as JSON text (not a parsed value).
    pub input: String,
}
82
/// Payload of [`ApiEvent::ToolResult`]: the outcome of a tool call.
#[derive(Debug, Clone, Serialize)]
pub struct ToolResultEvent {
    /// Identifier of the tool call this result belongs to.
    pub tool_use_id: String,
    /// Tool output as text.
    pub content: String,
    /// Whether the tool reported a failure.
    pub is_error: bool,
}
89
/// Token usage reported for one request; payload of [`ApiEvent::Usage`].
#[derive(Debug, Clone, Serialize)]
pub struct UsageEvent {
    /// Tokens consumed by the request input.
    pub input_tokens: u64,
    /// Tokens generated by the model.
    pub output_tokens: u64,
    /// Input tokens served from the prompt cache (Anthropic's `cache_read_input_tokens`).
    pub cache_read_tokens: u64,
    /// Input tokens written to the prompt cache (Anthropic's `cache_creation_input_tokens`).
    pub cache_write_tokens: u64,
}
97
/// One item of a provider's [`EventStream`].
#[derive(Debug, Clone, Serialize)]
pub enum ApiEvent {
    /// A fragment of the assistant's visible answer; fragments concatenate in order.
    MessageDelta { text: String },
    /// A fragment of the model's reasoning.
    ThinkingDelta { text: String },
    /// The model requested a tool call (arguments fully assembled).
    ToolUse { tool_use: ToolUseEvent },
    /// The outcome of a tool call. No provider visible in this file emits it.
    ToolResult { tool_result: ToolResultEvent },
    /// Token accounting for the request.
    Usage { usage: UsageEvent },
    /// End of the response; consumers in this file stop reading when they see it.
    Completed,
}
107
/// Failure modes of a [`Provider`] call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderError {
    /// The request could not be sent or the response body could not be read.
    Network { message: String },
    /// The server answered with a non-success HTTP status; `message` is the response body.
    Api { status: u16, message: String },
    /// The response body was not the expected JSON.
    Parse { message: String },
    /// No response arrived within the configured per-request timeout.
    Timeout { elapsed_ms: u64 },
    /// A retried request still failed; `last_error` is the rendered final error.
    RetryExhausted { attempts: u32, last_error: String },
    /// Anything that fits none of the other variants.
    Other { message: String },
}
119
120impl std::fmt::Display for ProviderError {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 ProviderError::Network { message } => write!(f, "Network error: {message}"),
124 ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125 ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126 ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127 ProviderError::RetryExhausted {
128 attempts,
129 last_error,
130 } => {
131 write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132 }
133 ProviderError::Other { message } => write!(f, "{message}"),
134 }
135 }
136}
137
// Marker impl so `ProviderError` can be used wherever a `std::error::Error` is expected;
// no source error is exposed.
impl std::error::Error for ProviderError {}
139
140#[derive(Debug, Clone, Default, Serialize)]
143pub struct UsageAccount {
144 pub total_input_tokens: u64,
145 pub total_output_tokens: u64,
146 pub total_cache_read_tokens: u64,
147 pub total_cache_write_tokens: u64,
148 pub request_count: u64,
149}
150
151impl UsageAccount {
152 pub fn record(&mut self, usage: &UsageEvent) {
153 self.total_input_tokens += usage.input_tokens;
154 self.total_output_tokens += usage.output_tokens;
155 self.total_cache_read_tokens += usage.cache_read_tokens;
156 self.total_cache_write_tokens += usage.cache_write_tokens;
157 self.request_count += 1;
158 }
159}
160
/// Retry policy for provider requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,
    /// Delay used as the starting point for waits between attempts.
    pub base_delay: Duration,
    /// Upper bound for a single wait between attempts.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 200 ms and never waiting longer than 5 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        Self {
            max_attempts: 3,
            base_delay,
            max_delay,
        }
    }
}
179
/// Time limits applied to provider requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time to wait for a single request before treating it as timed out.
    pub per_request: Duration,
}

impl Default for TimeoutConfig {
    /// Allows each request one minute.
    fn default() -> Self {
        let per_request = Duration::from_secs(60);
        Self { per_request }
    }
}
192
/// Boxed, sendable stream of provider events; each item is either an event or the error
/// that interrupted the response.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
/// A backend that can answer a [`CompletionRequest`], either in one piece or incrementally.
pub trait Provider: Send + Sync {
    /// Produces the whole response at once.
    ///
    /// The returned future may borrow `self` but not `request`, so implementations copy
    /// whatever they need out of `request` before returning.
    fn complete(
        &self,
        request: &CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;

    /// Produces the response as a stream of [`ApiEvent`]s; failures arrive as `Err` items.
    fn stream(&self, request: &CompletionRequest) -> EventStream;
}
205
/// Owned, dynamically dispatched [`Provider`], as returned by [`create_provider`].
pub type BoxedProvider = Box<dyn Provider>;
209
210#[derive(Debug, Default, Clone)]
213pub struct MockProvider;
214
215impl Provider for MockProvider {
216 fn complete(
217 &self,
218 request: &CompletionRequest,
219 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220 let response = mock_complete_response(request);
221 Box::pin(async move { Ok(response) })
222 }
223
224 fn stream(&self, request: &CompletionRequest) -> EventStream {
225 let req = request.clone();
226 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228 tokio::spawn(async move {
229 let events = mock_stream_events(&req);
230 for (idx, event) in events.into_iter().enumerate() {
231 if idx > 0 {
232 tokio::time::sleep(Duration::from_millis(18)).await;
233 }
234 if tx.send(Ok(event)).await.is_err() {
235 return;
236 }
237 }
238 });
239
240 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241 rx.recv().await.map(|item| (item, rx))
242 }))
243 }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247 if wants_read_cargo_toml(&request.prompt) {
248 if let Some(tool_content) = first_tool_result_content(&request.messages) {
249 let response = mock_summarize_cargo_toml(&tool_content);
250 return CompletionResponse {
251 system_prompt: request.system_prompt_name.clone(),
252 response,
253 tool_count: request.tools.len(),
254 skill_count: request.skill_count,
255 mcp_server_count: request.mcp_servers.len(),
256 };
257 }
258
259 return CompletionResponse {
260 system_prompt: request.system_prompt_name.clone(),
261 response: "I'll read Cargo.toml first.".to_string(),
262 tool_count: request.tools.len(),
263 skill_count: request.skill_count,
264 mcp_server_count: request.mcp_servers.len(),
265 };
266 }
267
268 let response = format!(
269 "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n\nRequest queued for the execution loop.\n\nNext priorities:\n1. Parse instructions into an explicit task graph.\n2. Resolve tool approvals before execution.\n3. Stream structured updates into the terminal UI.",
270 request.model,
271 request.prompt_pack,
272 request.system_prompt_name,
273 request
274 .tools
275 .iter()
276 .map(|tool| tool.name.as_str())
277 .collect::<Vec<_>>()
278 .join(", "),
279 request.skill_count,
280 request.mcp_servers.len(),
281 );
282
283 CompletionResponse {
284 system_prompt: request.system_prompt_name.clone(),
285 response,
286 tool_count: request.tools.len(),
287 skill_count: request.skill_count,
288 mcp_server_count: request.mcp_servers.len(),
289 }
290}
291
292fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
293 if wants_read_cargo_toml(&request.prompt) {
294 if let Some(tool_content) = first_tool_result_content(&request.messages) {
295 return vec![
296 ApiEvent::ThinkingDelta {
297 text: "I have the Cargo.toml contents; summarizing.".to_string(),
298 },
299 ApiEvent::MessageDelta {
300 text: mock_summarize_cargo_toml(&tool_content),
301 },
302 ApiEvent::Usage {
303 usage: UsageEvent {
304 input_tokens: 220,
305 output_tokens: 90,
306 cache_read_tokens: 0,
307 cache_write_tokens: 0,
308 },
309 },
310 ApiEvent::Completed,
311 ];
312 }
313
314 return vec![
315 ApiEvent::ThinkingDelta {
316 text: "I should read Cargo.toml to answer this.".to_string(),
317 },
318 ApiEvent::ToolUse {
319 tool_use: ToolUseEvent {
320 id: "tool_1".to_string(),
321 name: "read_file".to_string(),
322 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
323 },
324 },
325 ApiEvent::Completed,
326 ];
327 }
328
329 let tool_names: Vec<&str> = request.tools.iter().map(|t| t.name.as_ref()).collect();
330
331 vec![
332 ApiEvent::ThinkingDelta {
333 text: "Let me start by analyzing the request.".to_string(),
334 },
335 ApiEvent::MessageDelta {
336 text: format!(
337 "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n",
338 request.model,
339 request.prompt_pack,
340 request.system_prompt_name,
341 tool_names.join(", "),
342 request.skill_count,
343 request.mcp_servers.len(),
344 ),
345 },
346 ApiEvent::MessageDelta {
347 text: "\nRequest queued for the execution loop.\n\nNext priorities:\n".to_string(),
348 },
349 ApiEvent::MessageDelta {
350 text: "1. Parse instructions into an explicit task graph.\n".to_string(),
351 },
352 ApiEvent::MessageDelta {
353 text: "2. Resolve tool approvals before execution.\n".to_string(),
354 },
355 ApiEvent::MessageDelta {
356 text: "3. Stream structured updates into the terminal UI.".to_string(),
357 },
358 ApiEvent::Usage {
359 usage: UsageEvent {
360 input_tokens: 120,
361 output_tokens: 85,
362 cache_read_tokens: 0,
363 cache_write_tokens: 0,
364 },
365 },
366 ApiEvent::Completed,
367 ]
368}
369
/// Reports whether `prompt` looks like a request to read or summarize `Cargo.toml`.
///
/// Matching is ASCII case-insensitive and purely substring based: the prompt must contain
/// `cargo.toml` and at least one of `read` or `summarize` anywhere in the text.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    if !lowered.contains("cargo.toml") {
        return false;
    }
    ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb))
}
374
375fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
376 for m in messages {
377 for b in &m.content {
378 if let ProviderContentBlock::ToolResult { content, .. } = b {
379 return Some(content.clone());
380 }
381 }
382 }
383 None
384}
385
/// Produces the mock "summary" of a `Cargo.toml` whose text is `contents`.
///
/// Only looks for the literal substrings `[workspace]` and `members`; the list of key
/// crates is fixed text, not derived from `contents`.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return String::from(
            "Cargo.toml does not look like a workspace manifest (no [workspace] section).",
        );
    }

    let members_note = if contents.contains("members") {
        "It declares workspace members; this repo is a multi-crate workspace.\n"
    } else {
        ""
    };

    [
        "Cargo.toml defines a Rust workspace.\n",
        members_note,
        "Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.",
    ]
    .concat()
}
398
399#[derive(Debug, Default, Clone)]
410pub struct MockToolProvider;
411
412impl Provider for MockToolProvider {
413 fn complete(
414 &self,
415 request: &CompletionRequest,
416 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
417 let response = mock_tool_complete_response(request);
418 Box::pin(async move { Ok(response) })
419 }
420
421 fn stream(&self, request: &CompletionRequest) -> EventStream {
422 let req = request.clone();
423 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
424
425 tokio::spawn(async move {
426 let events = mock_tool_stream_events(&req);
427 for (idx, event) in events.into_iter().enumerate() {
428 if idx > 0 {
429 tokio::time::sleep(Duration::from_millis(5)).await;
430 }
431 if tx.send(Ok(event)).await.is_err() {
432 return;
433 }
434 }
435 });
436
437 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
438 rx.recv().await.map(|item| (item, rx))
439 }))
440 }
441}
442
443fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
444 if has_tool_result_message(request) {
445 let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
446 return CompletionResponse {
447 system_prompt: request.system_prompt_name.clone(),
448 response,
449 tool_count: request.tools.len(),
450 skill_count: request.skill_count,
451 mcp_server_count: request.mcp_servers.len(),
452 };
453 }
454
455 CompletionResponse {
456 system_prompt: request.system_prompt_name.clone(),
457 response: "I'll read the Cargo.toml file.".to_string(),
458 tool_count: request.tools.len(),
459 skill_count: request.skill_count,
460 mcp_server_count: request.mcp_servers.len(),
461 }
462}
463
464fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
465 if has_tool_result_message(request) {
466 return vec![
467 ApiEvent::ThinkingDelta {
468 text: "I have the file contents now.".to_string(),
469 },
470 ApiEvent::MessageDelta {
471 text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
472 },
473 ApiEvent::Usage {
474 usage: UsageEvent {
475 input_tokens: 200,
476 output_tokens: 60,
477 cache_read_tokens: 0,
478 cache_write_tokens: 0,
479 },
480 },
481 ApiEvent::Completed,
482 ];
483 }
484
485 vec![
486 ApiEvent::ThinkingDelta {
487 text: "I should read the Cargo.toml file.".to_string(),
488 },
489 ApiEvent::ToolUse {
490 tool_use: ToolUseEvent {
491 id: "tool_1".to_string(),
492 name: "read_file".to_string(),
493 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
494 },
495 },
496 ApiEvent::Usage {
497 usage: UsageEvent {
498 input_tokens: 100,
499 output_tokens: 30,
500 cache_read_tokens: 0,
501 cache_write_tokens: 0,
502 },
503 },
504 ApiEvent::Completed,
505 ]
506}
507
508fn has_tool_result_message(request: &CompletionRequest) -> bool {
511 request.messages.iter().any(|m| {
512 m.content
513 .iter()
514 .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
515 })
516}
517
518#[cfg(feature = "anthropic")]
521pub mod anthropic_provider {
522 use super::*;
523 use reqwest::Client;
524
/// [`Provider`] backed by the Anthropic Messages HTTP API.
#[derive(Debug, Clone)]
pub struct AnthropicProvider {
    /// HTTP client used for every request.
    client: Client,
    /// Credential sent in the `x-api-key` header.
    api_key: String,
    /// Full URL that requests are POSTed to.
    endpoint: String,
    /// Value of the `anthropic-version` header.
    anthropic_version: String,
    /// Retry policy; used by `complete` only (streaming requests are not retried).
    retry: RetryConfig,
    /// Limit on how long sending a request may take before it counts as timed out.
    timeout: TimeoutConfig,
}
534
/// Turns a user-supplied endpoint or base URL into the full Messages API URL.
///
/// * A value that already contains `/v1/messages` is returned as is (apart from trimmed
///   surrounding whitespace), so query strings and custom prefixes survive.
/// * Otherwise trailing slashes are dropped and `/v1/messages` is appended. A base that
///   already ends in the version segment (`…/v1`) only gets `/messages`, so it does not
///   turn into `…/v1/v1/messages`.
pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
    // Values copied from shell config or `.env` files often carry stray whitespace.
    let endpoint = endpoint.trim();
    if endpoint.contains("/v1/messages") {
        return endpoint.to_string();
    }

    let base = endpoint.trim_end_matches('/');
    let root = base.strip_suffix("/v1").unwrap_or(base);
    format!("{}/v1/messages", root)
}
543
544 impl AnthropicProvider {
545 pub fn from_env() -> Option<Self> {
546 let api_key = std::env::var("ANTHROPIC_API_KEY")
547 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
548 .ok()?;
549 let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
550 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
551 .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
552 let endpoint = normalize_anthropic_endpoint(&endpoint);
553 let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
554 .unwrap_or_else(|_| "2023-06-01".to_string());
555 Some(Self {
556 client: Client::new(),
557 api_key,
558 endpoint,
559 anthropic_version,
560 retry: RetryConfig::default(),
561 timeout: TimeoutConfig::default(),
562 })
563 }
564
565 pub fn new(api_key: String, endpoint: String) -> Self {
566 Self {
567 client: Client::new(),
568 api_key,
569 endpoint,
570 anthropic_version: "2023-06-01".to_string(),
571 retry: RetryConfig::default(),
572 timeout: TimeoutConfig::default(),
573 }
574 }
575 }
576
577 fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
578 if tools.is_empty() {
579 return serde_json::Value::Null;
580 }
581 serde_json::Value::Array(
582 tools
583 .iter()
584 .map(|t| {
585 serde_json::json!({
586 "name": t.name,
587 "description": t.description,
588 "input_schema": t.input_schema,
589 })
590 })
591 .collect(),
592 )
593 }
594
595 fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
596 messages
597 .iter()
598 .map(|m| ProviderMessage {
599 role: m.role.clone(),
600 content: m
601 .content
602 .iter()
603 .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
604 .cloned()
605 .collect(),
606 })
607 .filter(|m| !m.content.is_empty())
608 .collect()
609 }
610
611 impl Provider for AnthropicProvider {
612 fn complete(
613 &self,
614 request: &CompletionRequest,
615 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
616 {
617 let req = request.clone();
618 let api_key = self.api_key.clone();
619 let client = self.client.clone();
620 let endpoint = self.endpoint.clone();
621 let anthropic_version = self.anthropic_version.clone();
622 let retry = self.retry.clone();
623 let timeout = self.timeout.clone();
624 Box::pin(async move {
625 let tools = build_anthropic_tools(&req.tools);
626 let messages = sanitize_messages_for_anthropic(&req.messages);
627 let messages = if messages.is_empty() {
628 serde_json::json!([{"role": "user", "content": req.prompt}])
629 } else {
630 serde_json::to_value(messages).unwrap_or_else(
631 |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
632 )
633 };
634 let mut body = serde_json::json!({
635 "model": req.model,
636 "max_tokens": 4096,
637 "system": req.system_prompt_body,
638 "messages": messages,
639 });
640 if !tools.is_null() {
641 body["tools"] = tools;
642 }
643
644 let mut last_err: Option<ProviderError> = None;
645
646 for attempt in 1..=retry.max_attempts {
647 let send_fut = client
648 .post(&endpoint)
649 .header("x-api-key", &api_key)
650 .header("anthropic-version", &anthropic_version)
651 .header("content-type", "application/json")
652 .json(&body)
653 .send();
654
655 let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
656 Ok(Ok(r)) => r,
657 Ok(Err(e)) => {
658 last_err = Some(ProviderError::Network {
659 message: e.to_string(),
660 });
661 if attempt < retry.max_attempts {
662 let backoff_ms = (retry.base_delay.as_millis() as u64)
663 .saturating_mul(1u64 << (attempt - 1));
664 tokio::time::sleep(Duration::from_millis(
665 backoff_ms.min(retry.max_delay.as_millis() as u64),
666 ))
667 .await;
668 continue;
669 }
670 break;
671 }
672 Err(_) => {
673 last_err = Some(ProviderError::Timeout {
674 elapsed_ms: timeout.per_request.as_millis() as u64,
675 });
676 if attempt < retry.max_attempts {
677 tokio::time::sleep(retry.base_delay).await;
678 continue;
679 }
680 break;
681 }
682 };
683
684 let status = resp.status().as_u16();
685 if !resp.status().is_success() {
686 let text = resp.text().await.unwrap_or_default();
687 let err = ProviderError::Api {
688 status,
689 message: text,
690 };
691 last_err = Some(err);
692
693 let retryable = (500..=599).contains(&status);
695 if retryable && attempt < retry.max_attempts {
696 tokio::time::sleep(retry.base_delay).await;
697 continue;
698 }
699 break;
700 }
701
702 let json: serde_json::Value =
703 resp.json().await.map_err(|e| ProviderError::Parse {
704 message: e.to_string(),
705 })?;
706
707 let response = json["content"]
708 .as_array()
709 .map(|arr| {
710 arr.iter()
711 .filter_map(|block| block["text"].as_str())
712 .collect::<Vec<_>>()
713 .join("")
714 })
715 .unwrap_or_default();
716
717 return Ok(CompletionResponse {
718 system_prompt: req.system_prompt_name.clone(),
719 response,
720 tool_count: req.tools.len(),
721 skill_count: req.skill_count,
722 mcp_server_count: req.mcp_servers.len(),
723 });
724 }
725
726 Err(match last_err {
727 Some(e) => ProviderError::RetryExhausted {
728 attempts: retry.max_attempts,
729 last_error: e.to_string(),
730 },
731 None => ProviderError::Other {
732 message: "request failed".to_string(),
733 },
734 })
735 })
736 }
737
738 fn stream(&self, request: &CompletionRequest) -> EventStream {
739 let req = request.clone();
740 let api_key = self.api_key.clone();
741 let client = self.client.clone();
742 let endpoint = self.endpoint.clone();
743 let anthropic_version = self.anthropic_version.clone();
744 let timeout = self.timeout.clone();
745
746 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
747
748 tokio::spawn(async move {
749 let tools = build_anthropic_tools(&req.tools);
750 let messages = sanitize_messages_for_anthropic(&req.messages);
751 let messages = if messages.is_empty() {
752 serde_json::json!([{"role": "user", "content": req.prompt}])
753 } else {
754 serde_json::to_value(messages).unwrap_or_else(
755 |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
756 )
757 };
758 let mut body = serde_json::json!({
759 "model": req.model,
760 "max_tokens": 4096,
761 "system": req.system_prompt_body,
762 "messages": messages,
763 "stream": true,
764 });
765 if !tools.is_null() {
766 body["tools"] = tools;
767 }
768
769 let send_fut = client
770 .post(&endpoint)
771 .header("x-api-key", &api_key)
772 .header("anthropic-version", &anthropic_version)
773 .header("content-type", "application/json")
774 .json(&body)
775 .send();
776
777 let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
778 Ok(Ok(r)) => r,
779 Ok(Err(e)) => {
780 let _ = tx
781 .send(Err(ProviderError::Network {
782 message: e.to_string(),
783 }))
784 .await;
785 return;
786 }
787 Err(_) => {
788 let _ = tx
789 .send(Err(ProviderError::Timeout {
790 elapsed_ms: timeout.per_request.as_millis() as u64,
791 }))
792 .await;
793 return;
794 }
795 };
796
797 if !resp.status().is_success() {
798 let status = resp.status().as_u16();
799 let text = resp.text().await.unwrap_or_default();
800 let _ = tx
801 .send(Err(ProviderError::Api {
802 status,
803 message: text,
804 }))
805 .await;
806 return;
807 }
808
809 let mut buf = String::new();
810 let mut bytes = resp.bytes_stream();
811 use futures_util::StreamExt;
812
813 let mut current_tool_use: Option<(String, String, String)> = None;
814
815 while let Some(chunk) = bytes.next().await {
816 let chunk = match chunk {
817 Ok(c) => c,
818 Err(e) => {
819 let _ = tx
820 .send(Err(ProviderError::Network {
821 message: e.to_string(),
822 }))
823 .await;
824 return;
825 }
826 };
827
828 buf.push_str(&String::from_utf8_lossy(&chunk));
829
830 while let Some(idx) = buf.find("\n\n") {
832 let frame: String = buf.drain(..(idx + 2)).collect();
833 let mut data_lines = Vec::new();
834 for line in frame.lines() {
835 let line = line.trim();
836 if let Some(rest) = line.strip_prefix("data:") {
837 let payload = rest.trim();
838 if !payload.is_empty() {
839 data_lines.push(payload.to_string());
840 }
841 }
842 }
843
844 if data_lines.is_empty() {
845 continue;
846 }
847
848 let data = data_lines.join("\n");
849 if data == "[DONE]" {
850 continue;
851 }
852
853 let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
854 continue;
855 };
856 let typ = event["type"].as_str().unwrap_or("");
857 match typ {
858 "content_block_start" => {
859 let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
860 if cb_type == "text" {
861 if let Some(t) = event["content_block"]["text"].as_str() {
862 if !t.is_empty() {
863 let _ = tx
864 .send(Ok(ApiEvent::MessageDelta {
865 text: t.to_string(),
866 }))
867 .await;
868 }
869 }
870 } else if cb_type == "tool_use" {
871 let id = event["content_block"]["id"]
872 .as_str()
873 .unwrap_or("")
874 .to_string();
875 let name = event["content_block"]["name"]
876 .as_str()
877 .unwrap_or("")
878 .to_string();
879 let input = event["content_block"]["input"].clone();
880 let input_str = if input.is_null() {
881 String::new()
882 } else {
883 serde_json::to_string(&input).unwrap_or_default()
884 };
885 current_tool_use = Some((id, name, input_str));
886 } else if cb_type == "thinking" {
887 if let Some(t) = event["content_block"]["thinking"].as_str() {
888 if !t.is_empty() {
889 let _ = tx
890 .send(Ok(ApiEvent::ThinkingDelta {
891 text: t.to_string(),
892 }))
893 .await;
894 }
895 }
896 }
897 }
898 "content_block_delta" => {
899 let delta_type = event["delta"]["type"].as_str().unwrap_or("");
900 match delta_type {
901 "text_delta" => {
902 if let Some(t) = event["delta"]["text"].as_str() {
903 let _ = tx
904 .send(Ok(ApiEvent::MessageDelta {
905 text: t.to_string(),
906 }))
907 .await;
908 }
909 }
910 "thinking_delta" => {
911 if let Some(t) = event["delta"]["thinking"].as_str() {
912 let _ = tx
913 .send(Ok(ApiEvent::ThinkingDelta {
914 text: t.to_string(),
915 }))
916 .await;
917 }
918 }
919 "input_json_delta" => {
920 if let Some(partial) =
921 event["delta"]["partial_json"].as_str()
922 {
923 if let Some((_id, _name, input_buf)) =
924 current_tool_use.as_mut()
925 {
926 input_buf.push_str(partial);
927 }
928 }
929 }
930 _ => {}
931 }
932 }
933 "content_block_stop" => {
934 if let Some((id, name, input)) = current_tool_use.take() {
935 let input = if input.trim().is_empty() {
936 "{}".to_string()
937 } else {
938 input
939 };
940 let _ = tx
941 .send(Ok(ApiEvent::ToolUse {
942 tool_use: ToolUseEvent { id, name, input },
943 }))
944 .await;
945 }
946 }
947 "message_delta" => {
948 if let Some(usage) = event.get("usage") {
949 let _ = tx
950 .send(Ok(ApiEvent::Usage {
951 usage: UsageEvent {
952 input_tokens: usage["input_tokens"]
953 .as_u64()
954 .unwrap_or(0),
955 output_tokens: usage["output_tokens"]
956 .as_u64()
957 .unwrap_or(0),
958 cache_read_tokens: usage["cache_read_input_tokens"]
959 .as_u64()
960 .unwrap_or(0),
961 cache_write_tokens:
962 usage["cache_creation_input_tokens"]
963 .as_u64()
964 .unwrap_or(0),
965 },
966 }))
967 .await;
968 }
969 }
970 "message_stop" => {
971 let _ = tx.send(Ok(ApiEvent::Completed)).await;
972 return;
973 }
974 _ => {}
975 }
976 }
977 }
978
979 let _ = tx.send(Ok(ApiEvent::Completed)).await;
980 });
981
982 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
983 rx.recv().await.map(|item| (item, rx))
984 }))
985 }
986 }
987}
988
989pub fn create_provider() -> BoxedProvider {
992 let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
993
994 match provider_name.as_str() {
995 #[cfg(feature = "anthropic")]
996 "anthropic" => {
997 if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
998 tracing::info!("Using Anthropic provider");
999 return Box::new(p);
1000 }
1001 tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
1002 }
1003 #[cfg(not(feature = "anthropic"))]
1004 "anthropic" => {
1005 tracing::warn!(
1006 "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
1007 );
1008 }
1009 _ => {}
1010 }
1011
1012 tracing::info!("Using mock provider");
1013 Box::new(MockProvider)
1014}
1015
1016pub async fn collect_stream_to_response(
1019 stream: EventStream,
1020 request: &CompletionRequest,
1021) -> Result<CompletionResponse, ProviderError> {
1022 use futures_util::StreamExt;
1023 let mut text = String::new();
1024 let mut thinking = String::new();
1025
1026 let mut s = stream;
1027 while let Some(event) = s.next().await {
1028 match event? {
1029 ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1030 ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1031 ApiEvent::Usage { usage: _ } => {}
1032 ApiEvent::Completed => break,
1033 ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1034 }
1035 }
1036
1037 Ok(CompletionResponse {
1038 system_prompt: request.system_prompt_name.clone(),
1039 response: text,
1040 tool_count: request.tools.len(),
1041 skill_count: request.skill_count,
1042 mcp_server_count: request.mcp_servers.len(),
1043 })
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request: no history, tools or MCP servers, and a prompt that does not
    /// trigger the mock's Cargo.toml path.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Drains the mock stream for `request`, dropping error items.
    async fn mock_events(request: &CompletionRequest) -> Vec<ApiEvent> {
        MockProvider
            .stream(request)
            .filter_map(|e| async move { e.ok() })
            .collect()
            .await
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = mock_events(&sample_request()).await;

        let delta_count = events
            .iter()
            .filter(|e| matches!(e, ApiEvent::MessageDelta { .. }))
            .count();
        assert!(delta_count > 1, "expected several deltas, got {delta_count}");

        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = mock_events(&sample_request()).await;

        assert!(matches!(events[0], ApiEvent::ThinkingDelta { .. }));
        assert!(matches!(events[1], ApiEvent::MessageDelta { .. }));

        let usage_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let request = sample_request();

        let direct = MockProvider.complete(&request).await.unwrap();
        let streamed: String = mock_events(&request)
            .await
            .iter()
            .filter_map(|e| match e {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let mut account = UsageAccount::default();
        account.record(&UsageEvent {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 20,
        });
        account.record(&UsageEvent {
            input_tokens: 200,
            output_tokens: 75,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        });

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let err = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(err.to_string().contains("5000"));

        let err = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(err.to_string().contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::normalize_anthropic_endpoint;

        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/v1/messages"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_custom_path() {
            // A base URL with its own path prefix keeps that prefix.
            assert_eq!(
                normalize_anthropic_endpoint("https://proxy.example.com/anthropic"),
                "https://proxy.example.com/anthropic/v1/messages"
            );
        }
    }
}