1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11#[derive(Debug, Clone)]
14pub struct CompletionRequest {
15 pub model: String,
16 pub prompt_pack: String,
17 pub system_prompt_name: String,
18 pub system_prompt_body: String,
19 pub prompt: String,
20 pub messages: Vec<ProviderMessage>,
22 pub tools: Vec<ToolSpec>,
23 pub skill_count: usize,
24 pub mcp_servers: BTreeMap<String, McpServerConfig>,
25}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct CompletionResponse {
29 pub system_prompt: String,
30 pub response: String,
31 pub tool_count: usize,
32 pub skill_count: usize,
33 pub mcp_server_count: usize,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum ProviderRole {
41 User,
42 Assistant,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46pub struct ProviderMessage {
47 pub role: ProviderRole,
48 pub content: Vec<ProviderContentBlock>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ProviderContentBlock {
54 Text {
55 text: String,
56 },
57 ToolUse {
58 id: String,
59 name: String,
60 #[serde(default)]
61 input: serde_json::Value,
62 },
63 ToolResult {
64 tool_use_id: String,
65 content: String,
66 #[serde(default)]
67 is_error: bool,
68 },
69 Thinking {
70 thinking: String,
71 },
72}
73
74#[derive(Debug, Clone, Serialize)]
77pub struct ToolUseEvent {
78 pub id: String,
79 pub name: String,
80 pub input: String,
81}
82
83#[derive(Debug, Clone, Serialize)]
84pub struct ToolResultEvent {
85 pub tool_use_id: String,
86 pub content: String,
87 pub is_error: bool,
88}
89
90#[derive(Debug, Clone, Serialize)]
91pub struct UsageEvent {
92 pub input_tokens: u64,
93 pub output_tokens: u64,
94 pub cache_read_tokens: u64,
95 pub cache_write_tokens: u64,
96}
97
98#[derive(Debug, Clone, Serialize)]
99pub enum ApiEvent {
100 MessageDelta { text: String },
101 ThinkingDelta { text: String },
102 ToolUse { tool_use: ToolUseEvent },
103 ToolResult { tool_result: ToolResultEvent },
104 Usage { usage: UsageEvent },
105 Completed,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
111pub enum ProviderError {
112 Network { message: String },
113 Api { status: u16, message: String },
114 Parse { message: String },
115 Timeout { elapsed_ms: u64 },
116 RetryExhausted { attempts: u32, last_error: String },
117 Other { message: String },
118}
119
120impl std::fmt::Display for ProviderError {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 ProviderError::Network { message } => write!(f, "Network error: {message}"),
124 ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125 ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126 ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127 ProviderError::RetryExhausted {
128 attempts,
129 last_error,
130 } => {
131 write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132 }
133 ProviderError::Other { message } => write!(f, "{message}"),
134 }
135 }
136}
137
138impl std::error::Error for ProviderError {}
139
140#[derive(Debug, Clone, Default, Serialize)]
143pub struct UsageAccount {
144 pub total_input_tokens: u64,
145 pub total_output_tokens: u64,
146 pub total_cache_read_tokens: u64,
147 pub total_cache_write_tokens: u64,
148 pub request_count: u64,
149}
150
151impl UsageAccount {
152 pub fn record(&mut self, usage: &UsageEvent) {
153 self.total_input_tokens += usage.input_tokens;
154 self.total_output_tokens += usage.output_tokens;
155 self.total_cache_read_tokens += usage.cache_read_tokens;
156 self.total_cache_write_tokens += usage.cache_write_tokens;
157 self.request_count += 1;
158 }
159}
160
/// Retry policy for provider requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on the number of attempts, including the first one.
    pub max_attempts: u32,
    /// Delay used as the starting point for backoff between attempts.
    pub base_delay: Duration,
    /// Ceiling applied to the backoff delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 200 ms and capped at 5 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        RetryConfig {
            max_attempts: 3,
            base_delay,
            max_delay,
        }
    }
}
179
/// Time limits applied to provider requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time to wait for a single request.
    pub per_request: Duration,
}

impl Default for TimeoutConfig {
    /// Sixty seconds per request.
    fn default() -> Self {
        let per_request = Duration::from_secs(60);
        TimeoutConfig { per_request }
    }
}
192
193pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
197pub trait Provider: Send + Sync {
198 fn complete(
199 &self,
200 request: &CompletionRequest,
201 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;
202
203 fn stream(&self, request: &CompletionRequest) -> EventStream;
204}
205
206pub type BoxedProvider = Box<dyn Provider>;
209
210#[derive(Debug, Default, Clone)]
213pub struct MockProvider;
214
215impl Provider for MockProvider {
216 fn complete(
217 &self,
218 request: &CompletionRequest,
219 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220 let response = mock_complete_response(request);
221 Box::pin(async move { Ok(response) })
222 }
223
224 fn stream(&self, request: &CompletionRequest) -> EventStream {
225 let req = request.clone();
226 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228 tokio::spawn(async move {
229 let events = mock_stream_events(&req);
230 for (idx, event) in events.into_iter().enumerate() {
231 if idx > 0 {
232 tokio::time::sleep(Duration::from_millis(18)).await;
233 }
234 if tx.send(Ok(event)).await.is_err() {
235 return;
236 }
237 }
238 });
239
240 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241 rx.recv().await.map(|item| (item, rx))
242 }))
243 }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247 if wants_read_cargo_toml(&request.prompt) {
248 if let Some(tool_content) = first_tool_result_content(&request.messages) {
249 let response = mock_summarize_cargo_toml(&tool_content);
250 return CompletionResponse {
251 system_prompt: request.system_prompt_name.clone(),
252 response,
253 tool_count: request.tools.len(),
254 skill_count: request.skill_count,
255 mcp_server_count: request.mcp_servers.len(),
256 };
257 }
258
259 return CompletionResponse {
260 system_prompt: request.system_prompt_name.clone(),
261 response: "I'll read Cargo.toml first.".to_string(),
262 tool_count: request.tools.len(),
263 skill_count: request.skill_count,
264 mcp_server_count: request.mcp_servers.len(),
265 };
266 }
267
268 let response = mock_plain_reply(request);
269
270 CompletionResponse {
271 system_prompt: request.system_prompt_name.clone(),
272 response,
273 tool_count: request.tools.len(),
274 skill_count: request.skill_count,
275 mcp_server_count: request.mcp_servers.len(),
276 }
277}
278
279fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
280 if wants_read_cargo_toml(&request.prompt) {
281 if let Some(tool_content) = first_tool_result_content(&request.messages) {
282 return vec![
283 ApiEvent::ThinkingDelta {
284 text: "I have the Cargo.toml contents; summarizing.".to_string(),
285 },
286 ApiEvent::MessageDelta {
287 text: mock_summarize_cargo_toml(&tool_content),
288 },
289 ApiEvent::Usage {
290 usage: UsageEvent {
291 input_tokens: 220,
292 output_tokens: 90,
293 cache_read_tokens: 0,
294 cache_write_tokens: 0,
295 },
296 },
297 ApiEvent::Completed,
298 ];
299 }
300
301 return vec![
302 ApiEvent::ThinkingDelta {
303 text: "I should read Cargo.toml to answer this.".to_string(),
304 },
305 ApiEvent::ToolUse {
306 tool_use: ToolUseEvent {
307 id: "tool_1".to_string(),
308 name: "read_file".to_string(),
309 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
310 },
311 },
312 ApiEvent::Completed,
313 ];
314 }
315
316 vec![
317 ApiEvent::ThinkingDelta {
318 text: "Thinking about how to help...".to_string(),
319 },
320 ApiEvent::MessageDelta {
321 text: mock_plain_reply(request),
322 },
323 ApiEvent::Usage {
324 usage: UsageEvent {
325 input_tokens: 80,
326 output_tokens: 40,
327 cache_read_tokens: 0,
328 cache_write_tokens: 0,
329 },
330 },
331 ApiEvent::Completed,
332 ]
333}
334
335fn mock_plain_reply(request: &CompletionRequest) -> String {
336 let prompt = request.prompt.trim().to_ascii_lowercase();
337
338 if prompt.is_empty() {
339 return "Hello! How can I help you today?".to_string();
340 }
341
342 if ["hello", "hi", "hey"]
343 .iter()
344 .any(|greeting| prompt == *greeting)
345 {
346 return "Hello! How can I assist you today?".to_string();
347 }
348
349 if prompt.contains("how are you") {
350 return "I'm doing well, thank you! What can I help you with?".to_string();
351 }
352
353 "I received your message. To get started with real AI-powered assistance, configure a provider like Claude or Ollama.".to_string()
354}
355
/// Returns `true` when the prompt mentions `cargo.toml` together with
/// "read" or "summarize" (ASCII case-insensitive).
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    if !lowered.contains("cargo.toml") {
        return false;
    }
    ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb))
}
360
361fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
362 for m in messages {
363 for b in &m.content {
364 if let ProviderContentBlock::ToolResult { content, .. } = b {
365 return Some(content.clone());
366 }
367 }
368 }
369 None
370}
371
/// Produces a canned summary of a `Cargo.toml` based on simple substring
/// checks for `[workspace]` and `members`.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let mut summary = String::new();
    summary += "Cargo.toml defines a Rust workspace.\n";
    if contents.contains("members") {
        summary += "It declares workspace members; this repo is a multi-crate workspace.\n";
    }
    summary += "Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.";
    summary
}
384
385#[derive(Debug, Default, Clone)]
396pub struct MockToolProvider;
397
398impl Provider for MockToolProvider {
399 fn complete(
400 &self,
401 request: &CompletionRequest,
402 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
403 let response = mock_tool_complete_response(request);
404 Box::pin(async move { Ok(response) })
405 }
406
407 fn stream(&self, request: &CompletionRequest) -> EventStream {
408 let req = request.clone();
409 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
410
411 tokio::spawn(async move {
412 let events = mock_tool_stream_events(&req);
413 for (idx, event) in events.into_iter().enumerate() {
414 if idx > 0 {
415 tokio::time::sleep(Duration::from_millis(5)).await;
416 }
417 if tx.send(Ok(event)).await.is_err() {
418 return;
419 }
420 }
421 });
422
423 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
424 rx.recv().await.map(|item| (item, rx))
425 }))
426 }
427}
428
429fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
430 if has_tool_result_message(request) {
431 let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
432 return CompletionResponse {
433 system_prompt: request.system_prompt_name.clone(),
434 response,
435 tool_count: request.tools.len(),
436 skill_count: request.skill_count,
437 mcp_server_count: request.mcp_servers.len(),
438 };
439 }
440
441 CompletionResponse {
442 system_prompt: request.system_prompt_name.clone(),
443 response: "I'll read the Cargo.toml file.".to_string(),
444 tool_count: request.tools.len(),
445 skill_count: request.skill_count,
446 mcp_server_count: request.mcp_servers.len(),
447 }
448}
449
450fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
451 if has_tool_result_message(request) {
452 return vec![
453 ApiEvent::ThinkingDelta {
454 text: "I have the file contents now.".to_string(),
455 },
456 ApiEvent::MessageDelta {
457 text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
458 },
459 ApiEvent::Usage {
460 usage: UsageEvent {
461 input_tokens: 200,
462 output_tokens: 60,
463 cache_read_tokens: 0,
464 cache_write_tokens: 0,
465 },
466 },
467 ApiEvent::Completed,
468 ];
469 }
470
471 vec![
472 ApiEvent::ThinkingDelta {
473 text: "I should read the Cargo.toml file.".to_string(),
474 },
475 ApiEvent::ToolUse {
476 tool_use: ToolUseEvent {
477 id: "tool_1".to_string(),
478 name: "read_file".to_string(),
479 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
480 },
481 },
482 ApiEvent::Usage {
483 usage: UsageEvent {
484 input_tokens: 100,
485 output_tokens: 30,
486 cache_read_tokens: 0,
487 cache_write_tokens: 0,
488 },
489 },
490 ApiEvent::Completed,
491 ]
492}
493
494fn has_tool_result_message(request: &CompletionRequest) -> bool {
497 request.messages.iter().any(|m| {
498 m.content
499 .iter()
500 .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
501 })
502}
503
504#[cfg(feature = "anthropic")]
507pub mod anthropic_provider {
508 use super::*;
509 use reqwest::Client;
510
511 #[derive(Debug, Clone)]
512 pub struct AnthropicProvider {
513 client: Client,
514 api_key: String,
515 endpoint: String,
516 anthropic_version: String,
517 retry: RetryConfig,
518 timeout: TimeoutConfig,
519 }
520
521 pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
522 if endpoint.contains("/v1/messages") {
523 endpoint.to_string()
524 } else {
525 let endpoint = endpoint.trim_end_matches('/');
526 format!("{}/v1/messages", endpoint)
527 }
528 }
529
530 impl AnthropicProvider {
531 pub fn from_env() -> Option<Self> {
532 let api_key = std::env::var("ANTHROPIC_API_KEY")
533 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
534 .ok()?;
535 let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
536 .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
537 .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
538 let endpoint = normalize_anthropic_endpoint(&endpoint);
539 let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
540 .unwrap_or_else(|_| "2023-06-01".to_string());
541 Some(Self {
542 client: Client::new(),
543 api_key,
544 endpoint,
545 anthropic_version,
546 retry: RetryConfig::default(),
547 timeout: TimeoutConfig::default(),
548 })
549 }
550
551 pub fn new(api_key: String, endpoint: String) -> Self {
552 Self {
553 client: Client::new(),
554 api_key,
555 endpoint,
556 anthropic_version: "2023-06-01".to_string(),
557 retry: RetryConfig::default(),
558 timeout: TimeoutConfig::default(),
559 }
560 }
561 }
562
563 fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
564 if tools.is_empty() {
565 return serde_json::Value::Null;
566 }
567 serde_json::Value::Array(
568 tools
569 .iter()
570 .map(|t| {
571 serde_json::json!({
572 "name": t.name,
573 "description": t.description,
574 "input_schema": t.input_schema,
575 })
576 })
577 .collect(),
578 )
579 }
580
581 fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
582 messages
583 .iter()
584 .map(|m| ProviderMessage {
585 role: m.role.clone(),
586 content: m
587 .content
588 .iter()
589 .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
590 .cloned()
591 .collect(),
592 })
593 .filter(|m| !m.content.is_empty())
594 .collect()
595 }
596
597 pub(crate) fn initial_tool_input_buffer(input: &serde_json::Value) -> String {
598 match input {
599 serde_json::Value::Null => String::new(),
600 serde_json::Value::Object(map) if map.is_empty() => String::new(),
601 serde_json::Value::String(text) if text.trim().is_empty() => String::new(),
602 other => serde_json::to_string(other).unwrap_or_default(),
603 }
604 }
605
606 pub(crate) fn finalize_tool_input_buffer(input: String) -> String {
607 if input.trim().is_empty() {
608 "{}".to_string()
609 } else {
610 input
611 }
612 }
613
614 impl Provider for AnthropicProvider {
615 fn complete(
616 &self,
617 request: &CompletionRequest,
618 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
619 {
620 let req = request.clone();
621 let api_key = self.api_key.clone();
622 let client = self.client.clone();
623 let endpoint = self.endpoint.clone();
624 let anthropic_version = self.anthropic_version.clone();
625 let retry = self.retry.clone();
626 let timeout = self.timeout.clone();
627 Box::pin(async move {
628 let tools = build_anthropic_tools(&req.tools);
629 let messages = sanitize_messages_for_anthropic(&req.messages);
630 let messages = if messages.is_empty() {
631 serde_json::json!([{"role": "user", "content": req.prompt}])
632 } else {
633 serde_json::to_value(messages).unwrap_or_else(
634 |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
635 )
636 };
637 let mut body = serde_json::json!({
638 "model": req.model,
639 "max_tokens": 4096,
640 "system": req.system_prompt_body,
641 "messages": messages,
642 });
643 if !tools.is_null() {
644 body["tools"] = tools;
645 }
646
647 let mut last_err: Option<ProviderError> = None;
648
649 for attempt in 1..=retry.max_attempts {
650 let send_fut = client
651 .post(&endpoint)
652 .header("x-api-key", &api_key)
653 .header("anthropic-version", &anthropic_version)
654 .header("content-type", "application/json")
655 .json(&body)
656 .send();
657
658 let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
659 Ok(Ok(r)) => r,
660 Ok(Err(e)) => {
661 last_err = Some(ProviderError::Network {
662 message: e.to_string(),
663 });
664 if attempt < retry.max_attempts {
665 let backoff_ms = (retry.base_delay.as_millis() as u64)
666 .saturating_mul(1u64 << (attempt - 1));
667 tokio::time::sleep(Duration::from_millis(
668 backoff_ms.min(retry.max_delay.as_millis() as u64),
669 ))
670 .await;
671 continue;
672 }
673 break;
674 }
675 Err(_) => {
676 last_err = Some(ProviderError::Timeout {
677 elapsed_ms: timeout.per_request.as_millis() as u64,
678 });
679 if attempt < retry.max_attempts {
680 tokio::time::sleep(retry.base_delay).await;
681 continue;
682 }
683 break;
684 }
685 };
686
687 let status = resp.status().as_u16();
688 if !resp.status().is_success() {
689 let text = resp.text().await.unwrap_or_default();
690 let err = ProviderError::Api {
691 status,
692 message: text,
693 };
694 last_err = Some(err);
695
696 let retryable = (500..=599).contains(&status);
698 if retryable && attempt < retry.max_attempts {
699 tokio::time::sleep(retry.base_delay).await;
700 continue;
701 }
702 break;
703 }
704
705 let json: serde_json::Value =
706 resp.json().await.map_err(|e| ProviderError::Parse {
707 message: e.to_string(),
708 })?;
709
710 let response = json["content"]
711 .as_array()
712 .map(|arr| {
713 arr.iter()
714 .filter_map(|block| block["text"].as_str())
715 .collect::<Vec<_>>()
716 .join("")
717 })
718 .unwrap_or_default();
719
720 return Ok(CompletionResponse {
721 system_prompt: req.system_prompt_name.clone(),
722 response,
723 tool_count: req.tools.len(),
724 skill_count: req.skill_count,
725 mcp_server_count: req.mcp_servers.len(),
726 });
727 }
728
729 Err(match last_err {
730 Some(e) => ProviderError::RetryExhausted {
731 attempts: retry.max_attempts,
732 last_error: e.to_string(),
733 },
734 None => ProviderError::Other {
735 message: "request failed".to_string(),
736 },
737 })
738 })
739 }
740
741 fn stream(&self, request: &CompletionRequest) -> EventStream {
742 let req = request.clone();
743 let api_key = self.api_key.clone();
744 let client = self.client.clone();
745 let endpoint = self.endpoint.clone();
746 let anthropic_version = self.anthropic_version.clone();
747 let timeout = self.timeout.clone();
748
749 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
750
751 tokio::spawn(async move {
752 let tools = build_anthropic_tools(&req.tools);
753 let messages = sanitize_messages_for_anthropic(&req.messages);
754 let messages = if messages.is_empty() {
755 serde_json::json!([{"role": "user", "content": req.prompt}])
756 } else {
757 serde_json::to_value(messages).unwrap_or_else(
758 |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
759 )
760 };
761 let mut body = serde_json::json!({
762 "model": req.model,
763 "max_tokens": 4096,
764 "system": req.system_prompt_body,
765 "messages": messages,
766 "stream": true,
767 });
768 if !tools.is_null() {
769 body["tools"] = tools;
770 }
771
772 let send_fut = client
773 .post(&endpoint)
774 .header("x-api-key", &api_key)
775 .header("anthropic-version", &anthropic_version)
776 .header("content-type", "application/json")
777 .json(&body)
778 .send();
779
780 let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
781 Ok(Ok(r)) => r,
782 Ok(Err(e)) => {
783 let _ = tx
784 .send(Err(ProviderError::Network {
785 message: e.to_string(),
786 }))
787 .await;
788 return;
789 }
790 Err(_) => {
791 let _ = tx
792 .send(Err(ProviderError::Timeout {
793 elapsed_ms: timeout.per_request.as_millis() as u64,
794 }))
795 .await;
796 return;
797 }
798 };
799
800 if !resp.status().is_success() {
801 let status = resp.status().as_u16();
802 let text = resp.text().await.unwrap_or_default();
803 let _ = tx
804 .send(Err(ProviderError::Api {
805 status,
806 message: text,
807 }))
808 .await;
809 return;
810 }
811
812 let mut buf = String::new();
813 let mut bytes = resp.bytes_stream();
814 use futures_util::StreamExt;
815
816 let mut current_tool_use: Option<(String, String, String)> = None;
817
818 while let Some(chunk) = bytes.next().await {
819 let chunk = match chunk {
820 Ok(c) => c,
821 Err(e) => {
822 let _ = tx
823 .send(Err(ProviderError::Network {
824 message: e.to_string(),
825 }))
826 .await;
827 return;
828 }
829 };
830
831 buf.push_str(&String::from_utf8_lossy(&chunk));
832
833 while let Some(idx) = buf.find("\n\n") {
835 let frame: String = buf.drain(..(idx + 2)).collect();
836 let mut data_lines = Vec::new();
837 for line in frame.lines() {
838 let line = line.trim();
839 if let Some(rest) = line.strip_prefix("data:") {
840 let payload = rest.trim();
841 if !payload.is_empty() {
842 data_lines.push(payload.to_string());
843 }
844 }
845 }
846
847 if data_lines.is_empty() {
848 continue;
849 }
850
851 let data = data_lines.join("\n");
852 if data == "[DONE]" {
853 continue;
854 }
855
856 let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
857 continue;
858 };
859 let typ = event["type"].as_str().unwrap_or("");
860 match typ {
861 "content_block_start" => {
862 let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
863 if cb_type == "text" {
864 if let Some(t) = event["content_block"]["text"].as_str() {
865 if !t.is_empty() {
866 let _ = tx
867 .send(Ok(ApiEvent::MessageDelta {
868 text: t.to_string(),
869 }))
870 .await;
871 }
872 }
873 } else if cb_type == "tool_use" {
874 let id = event["content_block"]["id"]
875 .as_str()
876 .unwrap_or("")
877 .to_string();
878 let name = event["content_block"]["name"]
879 .as_str()
880 .unwrap_or("")
881 .to_string();
882 let input =
883 initial_tool_input_buffer(&event["content_block"]["input"]);
884 let input_str = input;
885 current_tool_use = Some((id, name, input_str));
886 } else if cb_type == "thinking" {
887 if let Some(t) = event["content_block"]["thinking"].as_str() {
888 if !t.is_empty() {
889 let _ = tx
890 .send(Ok(ApiEvent::ThinkingDelta {
891 text: t.to_string(),
892 }))
893 .await;
894 }
895 }
896 }
897 }
898 "content_block_delta" => {
899 let delta_type = event["delta"]["type"].as_str().unwrap_or("");
900 match delta_type {
901 "text_delta" => {
902 if let Some(t) = event["delta"]["text"].as_str() {
903 let _ = tx
904 .send(Ok(ApiEvent::MessageDelta {
905 text: t.to_string(),
906 }))
907 .await;
908 }
909 }
910 "thinking_delta" => {
911 if let Some(t) = event["delta"]["thinking"].as_str() {
912 let _ = tx
913 .send(Ok(ApiEvent::ThinkingDelta {
914 text: t.to_string(),
915 }))
916 .await;
917 }
918 }
919 "input_json_delta" => {
920 if let Some(partial) =
921 event["delta"]["partial_json"].as_str()
922 {
923 if let Some((_id, _name, input_buf)) =
924 current_tool_use.as_mut()
925 {
926 input_buf.push_str(partial);
927 }
928 }
929 }
930 _ => {}
931 }
932 }
933 "content_block_stop" => {
934 if let Some((id, name, input)) = current_tool_use.take() {
935 let input = finalize_tool_input_buffer(input);
936 let _ = tx
937 .send(Ok(ApiEvent::ToolUse {
938 tool_use: ToolUseEvent { id, name, input },
939 }))
940 .await;
941 }
942 }
943 "message_delta" => {
944 if let Some(usage) = event.get("usage") {
945 let _ = tx
946 .send(Ok(ApiEvent::Usage {
947 usage: UsageEvent {
948 input_tokens: usage["input_tokens"]
949 .as_u64()
950 .unwrap_or(0),
951 output_tokens: usage["output_tokens"]
952 .as_u64()
953 .unwrap_or(0),
954 cache_read_tokens: usage["cache_read_input_tokens"]
955 .as_u64()
956 .unwrap_or(0),
957 cache_write_tokens:
958 usage["cache_creation_input_tokens"]
959 .as_u64()
960 .unwrap_or(0),
961 },
962 }))
963 .await;
964 }
965 }
966 "message_stop" => {
967 let _ = tx.send(Ok(ApiEvent::Completed)).await;
968 return;
969 }
970 _ => {}
971 }
972 }
973 }
974
975 let _ = tx.send(Ok(ApiEvent::Completed)).await;
976 });
977
978 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
979 rx.recv().await.map(|item| (item, rx))
980 }))
981 }
982 }
983}
984
985pub fn create_provider() -> BoxedProvider {
988 let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
989
990 match provider_name.as_str() {
991 #[cfg(feature = "anthropic")]
992 "anthropic" => {
993 if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
994 tracing::info!("Using Anthropic provider");
995 return Box::new(p);
996 }
997 tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
998 }
999 #[cfg(not(feature = "anthropic"))]
1000 "anthropic" => {
1001 tracing::warn!(
1002 "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
1003 );
1004 }
1005 _ => {}
1006 }
1007
1008 tracing::info!("Using mock provider");
1009 Box::new(MockProvider)
1010}
1011
1012pub async fn collect_stream_to_response(
1015 stream: EventStream,
1016 request: &CompletionRequest,
1017) -> Result<CompletionResponse, ProviderError> {
1018 use futures_util::StreamExt;
1019 let mut text = String::new();
1020 let mut thinking = String::new();
1021
1022 let mut s = stream;
1023 while let Some(event) = s.next().await {
1024 match event? {
1025 ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1026 ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1027 ApiEvent::Usage { usage: _ } => {}
1028 ApiEvent::Completed => break,
1029 ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1030 }
1031 }
1032
1033 Ok(CompletionResponse {
1034 system_prompt: request.system_prompt_name.clone(),
1035 response: text,
1036 tool_count: request.tools.len(),
1037 skill_count: request.skill_count,
1038 mcp_server_count: request.mcp_servers.len(),
1039 })
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request whose prompt is the greeting "hello".
    fn hello_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Streams `request` through the mock provider and keeps the `Ok` events.
    async fn successful_events(request: &CompletionRequest) -> Vec<ApiEvent> {
        let provider = MockProvider;
        provider
            .stream(request)
            .filter_map(|event| async move { event.ok() })
            .collect()
            .await
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = successful_events(&hello_request()).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = successful_events(&hello_request()).await;

        assert!(matches!(events[0], ApiEvent::ThinkingDelta { .. }));
        assert!(matches!(events[1], ApiEvent::MessageDelta { .. }));

        let index_of = |wanted: fn(&ApiEvent) -> bool| events.iter().position(|e| wanted(e));
        let usage_idx =
            index_of(|e| matches!(e, ApiEvent::Usage { .. })).expect("Usage event should exist");
        let completed_idx =
            index_of(|e| matches!(e, ApiEvent::Completed)).expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let provider = MockProvider;
        let request = hello_request();

        let direct = provider.complete(&request).await.unwrap();
        let events = successful_events(&request).await;

        let streamed: String = events
            .iter()
            .filter_map(|event| match event {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let reports = [
            UsageEvent {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_tokens: 10,
                cache_write_tokens: 20,
            },
            UsageEvent {
                input_tokens: 200,
                output_tokens: 75,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        ];

        let mut account = UsageAccount::default();
        for report in &reports {
            account.record(report);
        }

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let timeout = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(timeout.to_string().contains("5000"));

        let exhausted = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(exhausted.to_string().contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::{
            finalize_tool_input_buffer, initial_tool_input_buffer, normalize_anthropic_endpoint,
        };

        /// Expected result for every `localhost:11434` input below.
        const LOCAL_MESSAGES_URL: &str = "http://localhost:11434/v1/messages";

        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            let normalized = normalize_anthropic_endpoint("http://localhost:11434");
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            let normalized = normalize_anthropic_endpoint("http://localhost:11434/");
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            let normalized = normalize_anthropic_endpoint("http://localhost:11434/v1/messages");
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        #[test]
        fn test_normalize_anthropic_endpoint_custom_path() {
            let normalized = normalize_anthropic_endpoint("https://api.anthropic.com");
            assert_eq!(normalized, "https://api.anthropic.com/v1/messages");
        }

        #[test]
        fn initial_tool_input_buffer_drops_empty_object() {
            let empty_object = serde_json::json!({});
            assert_eq!(initial_tool_input_buffer(&empty_object), String::new());
        }

        #[test]
        fn initial_tool_input_buffer_keeps_non_empty_object() {
            let command = serde_json::json!({"command": "ls -la"});
            assert_eq!(
                initial_tool_input_buffer(&command),
                r#"{"command":"ls -la"}"#
            );
        }

        #[test]
        fn finalize_tool_input_buffer_defaults_empty_to_object() {
            assert_eq!(finalize_tool_input_buffer(String::new()), "{}");
        }
    }
}