1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11#[derive(Debug, Clone)]
14pub struct CompletionRequest {
15 pub model: String,
16 pub prompt_pack: String,
17 pub system_prompt_name: String,
18 pub system_prompt_body: String,
19 pub prompt: String,
20 pub messages: Vec<ProviderMessage>,
22 pub tools: Vec<ToolSpec>,
23 pub skill_count: usize,
24 pub mcp_servers: BTreeMap<String, McpServerConfig>,
25}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct CompletionResponse {
29 pub system_prompt: String,
30 pub response: String,
31 pub tool_count: usize,
32 pub skill_count: usize,
33 pub mcp_server_count: usize,
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum ProviderRole {
41 User,
42 Assistant,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46pub struct ProviderMessage {
47 pub role: ProviderRole,
48 pub content: Vec<ProviderContentBlock>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ProviderContentBlock {
54 Text {
55 text: String,
56 },
57 ToolUse {
58 id: String,
59 name: String,
60 #[serde(default)]
61 input: serde_json::Value,
62 },
63 ToolResult {
64 tool_use_id: String,
65 content: String,
66 #[serde(default)]
67 is_error: bool,
68 },
69 Thinking {
70 thinking: String,
71 },
72}
73
74#[derive(Debug, Clone, Serialize)]
77pub struct ToolUseEvent {
78 pub id: String,
79 pub name: String,
80 pub input: String,
81}
82
83#[derive(Debug, Clone, Serialize)]
84pub struct ToolResultEvent {
85 pub tool_use_id: String,
86 pub content: String,
87 pub is_error: bool,
88}
89
90#[derive(Debug, Clone, Serialize)]
91pub struct UsageEvent {
92 pub input_tokens: u64,
93 pub output_tokens: u64,
94 pub cache_read_tokens: u64,
95 pub cache_write_tokens: u64,
96}
97
98#[derive(Debug, Clone, Serialize)]
99pub enum ApiEvent {
100 MessageDelta { text: String },
101 ThinkingDelta { text: String },
102 ToolUse { tool_use: ToolUseEvent },
103 ToolResult { tool_result: ToolResultEvent },
104 Usage { usage: UsageEvent },
105 Completed,
106}
107
108#[derive(Debug, Clone, Serialize, Deserialize)]
111pub enum ProviderError {
112 Network { message: String },
113 Api { status: u16, message: String },
114 Parse { message: String },
115 Timeout { elapsed_ms: u64 },
116 RetryExhausted { attempts: u32, last_error: String },
117 Other { message: String },
118}
119
120impl std::fmt::Display for ProviderError {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 match self {
123 ProviderError::Network { message } => write!(f, "Network error: {message}"),
124 ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125 ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126 ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127 ProviderError::RetryExhausted {
128 attempts,
129 last_error,
130 } => {
131 write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132 }
133 ProviderError::Other { message } => write!(f, "{message}"),
134 }
135 }
136}
137
138impl std::error::Error for ProviderError {}
139
140#[derive(Debug, Clone, Default, Serialize)]
143pub struct UsageAccount {
144 pub total_input_tokens: u64,
145 pub total_output_tokens: u64,
146 pub total_cache_read_tokens: u64,
147 pub total_cache_write_tokens: u64,
148 pub request_count: u64,
149}
150
151impl UsageAccount {
152 pub fn record(&mut self, usage: &UsageEvent) {
153 self.total_input_tokens += usage.input_tokens;
154 self.total_output_tokens += usage.output_tokens;
155 self.total_cache_read_tokens += usage.cache_read_tokens;
156 self.total_cache_write_tokens += usage.cache_write_tokens;
157 self.request_count += 1;
158 }
159}
160
/// Retry policy for non-streaming requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total attempts, including the first one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub base_delay: Duration,
    /// Upper bound for any single retry delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 200 ms and never waiting more than 5 s.
    fn default() -> Self {
        let max_attempts = 3;
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_millis(5_000);
        RetryConfig {
            max_attempts,
            base_delay,
            max_delay,
        }
    }
}

/// Time limits applied to outgoing requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum wait for a single request to produce a response.
    pub per_request: Duration,
}

impl Default for TimeoutConfig {
    /// One minute per request.
    fn default() -> Self {
        TimeoutConfig {
            per_request: Duration::from_secs(60),
        }
    }
}
192
193pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
197pub trait Provider: Send + Sync {
198 fn complete(
199 &self,
200 request: &CompletionRequest,
201 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;
202
203 fn stream(&self, request: &CompletionRequest) -> EventStream;
204}
205
206pub type BoxedProvider = Box<dyn Provider>;
209
210#[derive(Debug, Default, Clone)]
213pub struct MockProvider;
214
215impl Provider for MockProvider {
216 fn complete(
217 &self,
218 request: &CompletionRequest,
219 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220 let response = mock_complete_response(request);
221 Box::pin(async move { Ok(response) })
222 }
223
224 fn stream(&self, request: &CompletionRequest) -> EventStream {
225 let req = request.clone();
226 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228 tokio::spawn(async move {
229 let events = mock_stream_events(&req);
230 for (idx, event) in events.into_iter().enumerate() {
231 if idx > 0 {
232 tokio::time::sleep(Duration::from_millis(18)).await;
233 }
234 if tx.send(Ok(event)).await.is_err() {
235 return;
236 }
237 }
238 });
239
240 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241 rx.recv().await.map(|item| (item, rx))
242 }))
243 }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247 if wants_read_cargo_toml(&request.prompt) {
248 if let Some(tool_content) = first_tool_result_content(&request.messages) {
249 let response = mock_summarize_cargo_toml(&tool_content);
250 return CompletionResponse {
251 system_prompt: request.system_prompt_name.clone(),
252 response,
253 tool_count: request.tools.len(),
254 skill_count: request.skill_count,
255 mcp_server_count: request.mcp_servers.len(),
256 };
257 }
258
259 return CompletionResponse {
260 system_prompt: request.system_prompt_name.clone(),
261 response: "I'll read Cargo.toml first.".to_string(),
262 tool_count: request.tools.len(),
263 skill_count: request.skill_count,
264 mcp_server_count: request.mcp_servers.len(),
265 };
266 }
267
268 let response = mock_plain_reply(request);
269
270 CompletionResponse {
271 system_prompt: request.system_prompt_name.clone(),
272 response,
273 tool_count: request.tools.len(),
274 skill_count: request.skill_count,
275 mcp_server_count: request.mcp_servers.len(),
276 }
277}
278
279fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
280 if wants_read_cargo_toml(&request.prompt) {
281 if let Some(tool_content) = first_tool_result_content(&request.messages) {
282 return vec![
283 ApiEvent::ThinkingDelta {
284 text: "I have the Cargo.toml contents; summarizing.".to_string(),
285 },
286 ApiEvent::MessageDelta {
287 text: mock_summarize_cargo_toml(&tool_content),
288 },
289 ApiEvent::Usage {
290 usage: UsageEvent {
291 input_tokens: 220,
292 output_tokens: 90,
293 cache_read_tokens: 0,
294 cache_write_tokens: 0,
295 },
296 },
297 ApiEvent::Completed,
298 ];
299 }
300
301 return vec![
302 ApiEvent::ThinkingDelta {
303 text: "I should read Cargo.toml to answer this.".to_string(),
304 },
305 ApiEvent::ToolUse {
306 tool_use: ToolUseEvent {
307 id: "tool_1".to_string(),
308 name: "read_file".to_string(),
309 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
310 },
311 },
312 ApiEvent::Completed,
313 ];
314 }
315
316 vec![
317 ApiEvent::ThinkingDelta {
318 text: "Thinking about how to help...".to_string(),
319 },
320 ApiEvent::MessageDelta {
321 text: mock_plain_reply(request),
322 },
323 ApiEvent::Usage {
324 usage: UsageEvent {
325 input_tokens: 80,
326 output_tokens: 40,
327 cache_read_tokens: 0,
328 cache_write_tokens: 0,
329 },
330 },
331 ApiEvent::Completed,
332 ]
333}
334
335fn mock_plain_reply(request: &CompletionRequest) -> String {
336 let prompt = request.prompt.trim().to_ascii_lowercase();
337
338 if prompt.is_empty() {
339 return "Hello! How can I help you today?".to_string();
340 }
341
342 if ["hello", "hi", "hey"]
343 .iter()
344 .any(|greeting| prompt == *greeting)
345 {
346 return "Hello! How can I assist you today?".to_string();
347 }
348
349 if prompt.contains("how are you") {
350 return "I'm doing well, thank you! What can I help you with?".to_string();
351 }
352
353 "I received your message. To get started with real AI-powered assistance, configure a provider like Claude or Ollama.".to_string()
354}
355
/// True when the prompt mentions `Cargo.toml` (any ASCII case) and asks to read or
/// summarize something.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    let mentions_manifest = lowered.contains("cargo.toml");
    let asks_to_read = ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb));
    mentions_manifest && asks_to_read
}
360
361fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
362 for m in messages {
363 for b in &m.content {
364 if let ProviderContentBlock::ToolResult { content, .. } = b {
365 return Some(content.clone());
366 }
367 }
368 }
369 None
370}
371
/// Canned summary of a manifest's text.
///
/// Only checks for a `[workspace]` section and the word `members`; the crate list
/// is hard-coded.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let mut summary = String::from("Cargo.toml defines a Rust workspace.\n");
    if contents.contains("members") {
        summary += "It declares workspace members; this repo is a multi-crate workspace.\n";
    }
    summary += "Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.";
    summary
}
384
385#[derive(Debug, Default, Clone)]
396pub struct MockToolProvider;
397
398impl Provider for MockToolProvider {
399 fn complete(
400 &self,
401 request: &CompletionRequest,
402 ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
403 let response = mock_tool_complete_response(request);
404 Box::pin(async move { Ok(response) })
405 }
406
407 fn stream(&self, request: &CompletionRequest) -> EventStream {
408 let req = request.clone();
409 let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
410
411 tokio::spawn(async move {
412 let events = mock_tool_stream_events(&req);
413 for (idx, event) in events.into_iter().enumerate() {
414 if idx > 0 {
415 tokio::time::sleep(Duration::from_millis(5)).await;
416 }
417 if tx.send(Ok(event)).await.is_err() {
418 return;
419 }
420 }
421 });
422
423 Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
424 rx.recv().await.map(|item| (item, rx))
425 }))
426 }
427}
428
429fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
430 if has_tool_result_message(request) {
431 let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
432 return CompletionResponse {
433 system_prompt: request.system_prompt_name.clone(),
434 response,
435 tool_count: request.tools.len(),
436 skill_count: request.skill_count,
437 mcp_server_count: request.mcp_servers.len(),
438 };
439 }
440
441 CompletionResponse {
442 system_prompt: request.system_prompt_name.clone(),
443 response: "I'll read the Cargo.toml file.".to_string(),
444 tool_count: request.tools.len(),
445 skill_count: request.skill_count,
446 mcp_server_count: request.mcp_servers.len(),
447 }
448}
449
450fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
451 if has_tool_result_message(request) {
452 return vec![
453 ApiEvent::ThinkingDelta {
454 text: "I have the file contents now.".to_string(),
455 },
456 ApiEvent::MessageDelta {
457 text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
458 },
459 ApiEvent::Usage {
460 usage: UsageEvent {
461 input_tokens: 200,
462 output_tokens: 60,
463 cache_read_tokens: 0,
464 cache_write_tokens: 0,
465 },
466 },
467 ApiEvent::Completed,
468 ];
469 }
470
471 vec![
472 ApiEvent::ThinkingDelta {
473 text: "I should read the Cargo.toml file.".to_string(),
474 },
475 ApiEvent::ToolUse {
476 tool_use: ToolUseEvent {
477 id: "tool_1".to_string(),
478 name: "read_file".to_string(),
479 input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
480 },
481 },
482 ApiEvent::Usage {
483 usage: UsageEvent {
484 input_tokens: 100,
485 output_tokens: 30,
486 cache_read_tokens: 0,
487 cache_write_tokens: 0,
488 },
489 },
490 ApiEvent::Completed,
491 ]
492}
493
494fn has_tool_result_message(request: &CompletionRequest) -> bool {
497 request.messages.iter().any(|m| {
498 m.content
499 .iter()
500 .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
501 })
502}
503
#[cfg(feature = "anthropic")]
pub mod anthropic_provider {
    //! [`Provider`] implementation backed by the Anthropic Messages API.
    //!
    //! `complete` sends one non-streaming request. It retries network failures,
    //! timeouts, HTTP 429 and HTTP 5xx with capped exponential backoff.
    //!
    //! `stream` sends a single `"stream": true` request and translates the
    //! server-sent events into [`ApiEvent`]s delivered over a channel.

    use super::*;
    use reqwest::Client;

    /// Endpoint used when no endpoint override is configured.
    const DEFAULT_ENDPOINT: &str = "https://api.anthropic.com/v1/messages";
    /// Default value of the `anthropic-version` header.
    const DEFAULT_ANTHROPIC_VERSION: &str = "2023-06-01";
    /// `max_tokens` sent with every request.
    const MAX_TOKENS: u32 = 4096;

    /// Item type carried by the streaming channel.
    type StreamItem = Result<ApiEvent, ProviderError>;

    /// Anthropic Messages API client.
    ///
    /// `Debug` is implemented by hand so the API key is never printed.
    #[derive(Clone)]
    pub struct AnthropicProvider {
        client: Client,
        api_key: String,
        endpoint: String,
        anthropic_version: String,
        retry: RetryConfig,
        timeout: TimeoutConfig,
    }

    impl std::fmt::Debug for AnthropicProvider {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("AnthropicProvider")
                .field("api_key", &"<redacted>")
                .field("endpoint", &self.endpoint)
                .field("anthropic_version", &self.anthropic_version)
                .field("retry", &self.retry)
                .field("timeout", &self.timeout)
                .finish_non_exhaustive()
        }
    }

    /// Appends `/v1/messages` to a base URL unless the URL already contains it.
    pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
        if endpoint.contains("/v1/messages") {
            endpoint.to_string()
        } else {
            let endpoint = endpoint.trim_end_matches('/');
            format!("{}/v1/messages", endpoint)
        }
    }

    /// Reads `name` from the environment, trimmed; blank values count as unset.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .map(|value| value.trim().to_string())
            .filter(|value| !value.is_empty())
    }

    impl AnthropicProvider {
        /// Builds a provider from environment variables.
        ///
        /// - Key: `ANTHROPIC_API_KEY`, else `ANTHROPIC_AUTH_TOKEN`.
        /// - Endpoint: `CLAWEDCODE_ANTHROPIC_ENDPOINT`, else `ANTHROPIC_BASE_URL`
        ///   (normalized), else the public API.
        /// - Version: `CLAWEDCODE_ANTHROPIC_VERSION`.
        ///
        /// Returns `None` when no non-blank key is set.
        pub fn from_env() -> Option<Self> {
            let api_key =
                non_empty_env("ANTHROPIC_API_KEY").or_else(|| non_empty_env("ANTHROPIC_AUTH_TOKEN"))?;
            let endpoint = non_empty_env("CLAWEDCODE_ANTHROPIC_ENDPOINT")
                .or_else(|| non_empty_env("ANTHROPIC_BASE_URL"))
                .unwrap_or_else(|| DEFAULT_ENDPOINT.to_string());
            let anthropic_version = non_empty_env("CLAWEDCODE_ANTHROPIC_VERSION")
                .unwrap_or_else(|| DEFAULT_ANTHROPIC_VERSION.to_string());
            Some(Self {
                client: Client::new(),
                api_key,
                endpoint: normalize_anthropic_endpoint(&endpoint),
                anthropic_version,
                retry: RetryConfig::default(),
                timeout: TimeoutConfig::default(),
            })
        }

        /// Builds a provider for an explicit key and endpoint.
        ///
        /// `endpoint` is used verbatim, without normalization.
        pub fn new(api_key: String, endpoint: String) -> Self {
            Self {
                client: Client::new(),
                api_key,
                endpoint,
                anthropic_version: DEFAULT_ANTHROPIC_VERSION.to_string(),
                retry: RetryConfig::default(),
                timeout: TimeoutConfig::default(),
            }
        }

        /// Starts a POST to the endpoint with the auth, version and content-type headers set.
        fn authorized_post(&self) -> reqwest::RequestBuilder {
            // NOTE(review): a key taken from ANTHROPIC_AUTH_TOKEN is also sent as
            // `x-api-key`; confirm the configured endpoint accepts it that way.
            self.client
                .post(&self.endpoint)
                .header("x-api-key", &self.api_key)
                .header("anthropic-version", &self.anthropic_version)
                .header("content-type", "application/json")
        }
    }

    /// Converts tool specs to the API's `tools` array; `None` when there are no tools.
    fn build_anthropic_tools(tools: &[ToolSpec]) -> Option<serde_json::Value> {
        if tools.is_empty() {
            return None;
        }
        Some(serde_json::Value::Array(
            tools
                .iter()
                .map(|t| {
                    serde_json::json!({
                        "name": t.name,
                        "description": t.description,
                        "input_schema": t.input_schema,
                    })
                })
                .collect(),
        ))
    }

    /// Drops `Thinking` blocks from the history, then drops messages left with no content.
    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
        messages
            .iter()
            .map(|m| ProviderMessage {
                role: m.role.clone(),
                content: m
                    .content
                    .iter()
                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
                    .cloned()
                    .collect(),
            })
            .filter(|m| !m.content.is_empty())
            .collect()
    }

    /// Builds the JSON request body shared by `complete` and `stream`.
    ///
    /// If the sanitized history is empty (or fails to serialize), `req.prompt` is
    /// sent as a single user message.
    fn build_request_body(req: &CompletionRequest, stream: bool) -> serde_json::Value {
        let prompt_only = || serde_json::json!([{"role": "user", "content": req.prompt}]);
        let sanitized = sanitize_messages_for_anthropic(&req.messages);
        let messages = if sanitized.is_empty() {
            prompt_only()
        } else {
            serde_json::to_value(sanitized).unwrap_or_else(|_| prompt_only())
        };

        let mut body = serde_json::json!({
            "model": req.model,
            "max_tokens": MAX_TOKENS,
            "system": req.system_prompt_body,
            "messages": messages,
        });
        if stream {
            body["stream"] = serde_json::Value::Bool(true);
        }
        if let Some(tools) = build_anthropic_tools(&req.tools) {
            body["tools"] = tools;
        }
        body
    }

    /// Statuses worth retrying: 429 (rate limited) and every 5xx (including 529 overloaded).
    fn is_retryable_status(status: u16) -> bool {
        status == 429 || (500..=599).contains(&status)
    }

    /// Delay before retrying after failed attempt number `attempt` (1-based).
    ///
    /// The delay is `base_delay * 2^(attempt-1)`, capped at `max_delay`.
    fn backoff_delay(retry: &RetryConfig, attempt: u32) -> Duration {
        // Cap the exponent so the shift cannot overflow for large `max_attempts`.
        let factor = 1u32 << attempt.saturating_sub(1).min(31);
        retry.base_delay.saturating_mul(factor).min(retry.max_delay)
    }

    /// Whole milliseconds in `d`, saturating at `u64::MAX`.
    fn duration_millis(d: Duration) -> u64 {
        u64::try_from(d.as_millis()).unwrap_or(u64::MAX)
    }

    /// Decodes a successful non-streaming response and concatenates its blocks' `text` fields.
    async fn read_text_blocks(resp: reqwest::Response) -> Result<String, ProviderError> {
        let json: serde_json::Value = resp.json().await.map_err(|e| ProviderError::Parse {
            message: e.to_string(),
        })?;
        Ok(json["content"]
            .as_array()
            .map(|blocks| {
                blocks
                    .iter()
                    .filter_map(|block| block["text"].as_str())
                    .collect::<String>()
            })
            .unwrap_or_default())
    }

    impl Provider for AnthropicProvider {
        /// Sends a non-streaming request, retrying transient failures.
        ///
        /// # Errors
        /// - A non-retryable HTTP status is returned immediately as [`ProviderError::Api`].
        /// - When every attempt fails, returns [`ProviderError::RetryExhausted`] wrapping
        ///   the last error.
        /// - An undecodable body yields [`ProviderError::Parse`].
        fn complete(
            &self,
            request: &CompletionRequest,
        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
        {
            let req = request.clone();
            Box::pin(async move {
                let body = build_request_body(&req, false);
                let max_attempts = self.retry.max_attempts;
                let mut last_err: Option<ProviderError> = None;

                for attempt in 1..=max_attempts {
                    let send_fut = self.authorized_post().json(&body).send();
                    let err = match tokio::time::timeout(self.timeout.per_request, send_fut).await
                    {
                        Ok(Ok(resp)) if resp.status().is_success() => {
                            let response = read_text_blocks(resp).await?;
                            return Ok(CompletionResponse {
                                system_prompt: req.system_prompt_name.clone(),
                                response,
                                tool_count: req.tools.len(),
                                skill_count: req.skill_count,
                                mcp_server_count: req.mcp_servers.len(),
                            });
                        }
                        Ok(Ok(resp)) => {
                            let status = resp.status().as_u16();
                            let message = resp.text().await.unwrap_or_default();
                            let err = ProviderError::Api { status, message };
                            if !is_retryable_status(status) {
                                // Client errors will fail the same way again; surface them as-is.
                                return Err(err);
                            }
                            err
                        }
                        Ok(Err(e)) => ProviderError::Network {
                            message: e.to_string(),
                        },
                        Err(_) => ProviderError::Timeout {
                            elapsed_ms: duration_millis(self.timeout.per_request),
                        },
                    };

                    last_err = Some(err);
                    if attempt < max_attempts {
                        tokio::time::sleep(backoff_delay(&self.retry, attempt)).await;
                    }
                }

                Err(match last_err {
                    Some(e) => ProviderError::RetryExhausted {
                        attempts: max_attempts,
                        last_error: e.to_string(),
                    },
                    // Only reachable when `max_attempts` is 0.
                    None => ProviderError::Other {
                        message: "request failed".to_string(),
                    },
                })
            })
        }

        /// Sends one streaming request (no retries) and forwards translated SSE events.
        ///
        /// Connection failures, timeouts and non-success statuses are delivered as a
        /// single `Err` item.
        fn stream(&self, request: &CompletionRequest) -> EventStream {
            let provider = self.clone();
            let req = request.clone();
            let (tx, rx) = mpsc::channel::<StreamItem>(64);

            tokio::spawn(async move {
                let body = build_request_body(&req, true);
                let send_fut = provider.authorized_post().json(&body).send();

                let resp = match tokio::time::timeout(provider.timeout.per_request, send_fut).await
                {
                    Ok(Ok(resp)) => resp,
                    Ok(Err(e)) => {
                        let _ = tx
                            .send(Err(ProviderError::Network {
                                message: e.to_string(),
                            }))
                            .await;
                        return;
                    }
                    Err(_) => {
                        let _ = tx
                            .send(Err(ProviderError::Timeout {
                                elapsed_ms: duration_millis(provider.timeout.per_request),
                            }))
                            .await;
                        return;
                    }
                };

                if !resp.status().is_success() {
                    let status = resp.status().as_u16();
                    let message = resp.text().await.unwrap_or_default();
                    let _ = tx.send(Err(ProviderError::Api { status, message })).await;
                    return;
                }

                pump_sse(resp, &tx).await;
            });

            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
                rx.recv().await.map(|item| (item, rx))
            }))
        }
    }

    /// Position of the first occurrence of `needle` in `haystack`.
    fn find_subslice(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }

    /// Removes the next complete SSE frame from `buf` and returns it as text (without
    /// the terminating blank line).
    ///
    /// Accepts both `\n\n` and `\r\n\r\n` terminators. Decoding happens only on whole
    /// frames, so UTF-8 sequences split across network chunks stay intact.
    fn next_frame(buf: &mut Vec<u8>) -> Option<String> {
        let lf = find_subslice(buf.as_slice(), b"\n\n").map(|at| (at, 2));
        let crlf = find_subslice(buf.as_slice(), b"\r\n\r\n").map(|at| (at, 4));
        let (end, separator_len) = match (lf, crlf) {
            (Some(a), Some(b)) => {
                if a.0 <= b.0 {
                    a
                } else {
                    b
                }
            }
            (a, b) => a.or(b)?,
        };
        let frame = String::from_utf8_lossy(&buf[..end]).into_owned();
        buf.drain(..end + separator_len);
        Some(frame)
    }

    /// Joins a frame's non-empty `data:` payloads with `\n`.
    ///
    /// Returns `None` when the frame has no data (e.g. comment or keep-alive frames).
    fn frame_data(frame: &str) -> Option<String> {
        let payloads: Vec<&str> = frame
            .lines()
            .filter_map(|line| line.trim().strip_prefix("data:"))
            .map(str::trim)
            .filter(|payload| !payload.is_empty())
            .collect();
        (!payloads.is_empty()).then(|| payloads.join("\n"))
    }

    /// A `tool_use` content block that is still receiving its input.
    struct PendingToolUse {
        id: String,
        name: String,
        /// `input` from `content_block_start`; an empty object when deltas follow.
        initial_input: serde_json::Value,
        /// Concatenated `input_json_delta.partial_json` fragments.
        partial_json: String,
    }

    impl PendingToolUse {
        /// Final tool call.
        ///
        /// Input is chosen in this order: the streamed JSON if any arrived, otherwise
        /// the initial input, otherwise `{}`.
        fn into_event(self) -> ToolUseEvent {
            let input = if !self.partial_json.trim().is_empty() {
                self.partial_json
            } else if self.initial_input.is_null() {
                "{}".to_string()
            } else {
                serde_json::to_string(&self.initial_input).unwrap_or_else(|_| "{}".to_string())
            };
            ToolUseEvent {
                id: self.id,
                name: self.name,
                input,
            }
        }
    }

    /// Stateful translation from Messages-API stream events to [`ApiEvent`]s.
    #[derive(Default)]
    struct SseTranslator {
        /// Tool-use block currently being assembled, if any.
        tool_use: Option<PendingToolUse>,
        /// `message.usage` from `message_start`.
        ///
        /// Input and cache counts arrive there, while `message_delta` usage may carry
        /// only output tokens.
        start_usage: serde_json::Value,
    }

    impl SseTranslator {
        /// Translates one decoded event into zero or more items appended to `out`.
        ///
        /// Returns `true` when the stream is finished (`message_stop` or `error`).
        fn handle(&mut self, event: &serde_json::Value, out: &mut Vec<StreamItem>) -> bool {
            match event["type"].as_str().unwrap_or("") {
                "message_start" => {
                    self.start_usage = event["message"]["usage"].clone();
                }
                "content_block_start" => self.start_block(&event["content_block"], out),
                "content_block_delta" => self.apply_delta(&event["delta"], out),
                "content_block_stop" => {
                    if let Some(pending) = self.tool_use.take() {
                        out.push(Ok(ApiEvent::ToolUse {
                            tool_use: pending.into_event(),
                        }));
                    }
                }
                "message_delta" => {
                    if let Some(usage) = event.get("usage") {
                        out.push(Ok(ApiEvent::Usage {
                            usage: self.merged_usage(usage),
                        }));
                    }
                }
                "message_stop" => {
                    out.push(Ok(ApiEvent::Completed));
                    return true;
                }
                "error" => {
                    // Errors can arrive mid-stream after a 200 response (e.g. overload).
                    let error = &event["error"];
                    let kind = error["type"].as_str().unwrap_or("error");
                    let message = error["message"].as_str().unwrap_or("unknown stream error");
                    out.push(Err(ProviderError::Other {
                        message: format!("stream error ({}): {}", kind, message),
                    }));
                    return true;
                }
                _ => {}
            }
            false
        }

        /// Handles `content_block_start`.
        ///
        /// Non-empty initial text or thinking is emitted immediately; a `tool_use`
        /// block starts a pending tool call.
        fn start_block(&mut self, block: &serde_json::Value, out: &mut Vec<StreamItem>) {
            match block["type"].as_str().unwrap_or("") {
                "text" => {
                    if let Some(text) = block["text"].as_str().filter(|t| !t.is_empty()) {
                        out.push(Ok(ApiEvent::MessageDelta {
                            text: text.to_string(),
                        }));
                    }
                }
                "thinking" => {
                    if let Some(text) = block["thinking"].as_str().filter(|t| !t.is_empty()) {
                        out.push(Ok(ApiEvent::ThinkingDelta {
                            text: text.to_string(),
                        }));
                    }
                }
                "tool_use" => {
                    self.tool_use = Some(PendingToolUse {
                        id: block["id"].as_str().unwrap_or("").to_string(),
                        name: block["name"].as_str().unwrap_or("").to_string(),
                        initial_input: block["input"].clone(),
                        partial_json: String::new(),
                    });
                }
                _ => {}
            }
        }

        /// Handles `content_block_delta` (text, thinking and tool-input JSON fragments).
        fn apply_delta(&mut self, delta: &serde_json::Value, out: &mut Vec<StreamItem>) {
            match delta["type"].as_str().unwrap_or("") {
                "text_delta" => {
                    if let Some(text) = delta["text"].as_str() {
                        out.push(Ok(ApiEvent::MessageDelta {
                            text: text.to_string(),
                        }));
                    }
                }
                "thinking_delta" => {
                    if let Some(text) = delta["thinking"].as_str() {
                        out.push(Ok(ApiEvent::ThinkingDelta {
                            text: text.to_string(),
                        }));
                    }
                }
                "input_json_delta" => {
                    if let (Some(fragment), Some(pending)) =
                        (delta["partial_json"].as_str(), self.tool_use.as_mut())
                    {
                        pending.partial_json.push_str(fragment);
                    }
                }
                _ => {}
            }
        }

        /// Usage from a `message_delta`.
        ///
        /// Any counter the delta omits falls back to the value from `message_start`,
        /// and then to 0.
        fn merged_usage(&self, delta: &serde_json::Value) -> UsageEvent {
            let pick = |key: &str| {
                delta[key]
                    .as_u64()
                    .or_else(|| self.start_usage[key].as_u64())
                    .unwrap_or(0)
            };
            UsageEvent {
                input_tokens: pick("input_tokens"),
                output_tokens: pick("output_tokens"),
                cache_read_tokens: pick("cache_read_input_tokens"),
                cache_write_tokens: pick("cache_creation_input_tokens"),
            }
        }
    }

    /// Reads the SSE body of `resp` and forwards translated events to `tx`.
    ///
    /// Stops early when the receiver is dropped or the stream signals completion
    /// or an error. If the body ends without `message_stop`, a final
    /// `ApiEvent::Completed` is still sent.
    async fn pump_sse(resp: reqwest::Response, tx: &mpsc::Sender<StreamItem>) {
        use futures_util::StreamExt;

        let mut bytes = resp.bytes_stream();
        let mut buf: Vec<u8> = Vec::new();
        let mut translator = SseTranslator::default();
        let mut outbox: Vec<StreamItem> = Vec::new();

        while let Some(chunk) = bytes.next().await {
            let chunk = match chunk {
                Ok(chunk) => chunk,
                Err(e) => {
                    let _ = tx
                        .send(Err(ProviderError::Network {
                            message: e.to_string(),
                        }))
                        .await;
                    return;
                }
            };
            buf.extend_from_slice(&chunk);

            while let Some(frame) = next_frame(&mut buf) {
                let Some(data) = frame_data(&frame) else {
                    continue;
                };
                if data == "[DONE]" {
                    continue;
                }
                let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
                    continue;
                };

                let finished = translator.handle(&event, &mut outbox);
                for item in outbox.drain(..) {
                    if tx.send(item).await.is_err() {
                        // Receiver dropped: nobody is listening, stop reading the body.
                        return;
                    }
                }
                if finished {
                    return;
                }
            }
        }

        let _ = tx.send(Ok(ApiEvent::Completed)).await;
    }
}
974
975pub fn create_provider() -> BoxedProvider {
978 let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
979
980 match provider_name.as_str() {
981 #[cfg(feature = "anthropic")]
982 "anthropic" => {
983 if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
984 tracing::info!("Using Anthropic provider");
985 return Box::new(p);
986 }
987 tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
988 }
989 #[cfg(not(feature = "anthropic"))]
990 "anthropic" => {
991 tracing::warn!(
992 "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
993 );
994 }
995 _ => {}
996 }
997
998 tracing::info!("Using mock provider");
999 Box::new(MockProvider)
1000}
1001
1002pub async fn collect_stream_to_response(
1005 stream: EventStream,
1006 request: &CompletionRequest,
1007) -> Result<CompletionResponse, ProviderError> {
1008 use futures_util::StreamExt;
1009 let mut text = String::new();
1010 let mut thinking = String::new();
1011
1012 let mut s = stream;
1013 while let Some(event) = s.next().await {
1014 match event? {
1015 ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1016 ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1017 ApiEvent::Usage { usage: _ } => {}
1018 ApiEvent::Completed => break,
1019 ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1020 }
1021 }
1022
1023 Ok(CompletionResponse {
1024 system_prompt: request.system_prompt_name.clone(),
1025 response: text,
1026 tool_count: request.tools.len(),
1027 skill_count: request.skill_count,
1028 mcp_server_count: request.mcp_servers.len(),
1029 })
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request: no history, tools, skills or MCP servers.
    fn request_with_prompt(prompt: &str) -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: prompt.to_string(),
            messages: Vec::new(),
            tools: Vec::new(),
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Runs `MockProvider::stream` to the end, keeping only `Ok` events.
    async fn ok_events(request: &CompletionRequest) -> Vec<ApiEvent> {
        MockProvider
            .stream(request)
            .filter_map(|item| async move { item.ok() })
            .collect::<Vec<_>>()
            .await
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = ok_events(&request_with_prompt("hello")).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = ok_events(&request_with_prompt("hello")).await;

        assert!(matches!(events[0], ApiEvent::ThinkingDelta { .. }));
        assert!(matches!(events[1], ApiEvent::MessageDelta { .. }));

        let usage_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let request = request_with_prompt("hello");

        let direct = MockProvider.complete(&request).await.unwrap();
        let streamed: String = ok_events(&request)
            .await
            .iter()
            .filter_map(|e| match e {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let mut account = UsageAccount::default();
        account.record(&UsageEvent {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 20,
        });
        account.record(&UsageEvent {
            input_tokens: 200,
            output_tokens: 75,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        });

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let timeout = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(timeout.to_string().contains("5000"));

        let exhausted = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(exhausted.to_string().contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::normalize_anthropic_endpoint;

        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/v1/messages"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_custom_path() {
            assert_eq!(
                normalize_anthropic_endpoint("https://api.anthropic.com"),
                "https://api.anthropic.com/v1/messages"
            );
        }
    }
}