Skip to main content

clawedcode_api/
lib.rs

1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11// --- Request / Response types ---
12
/// Everything a [`Provider`] needs to answer one turn.
///
/// Passed by reference to [`Provider::complete`] and [`Provider::stream`].
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier; the Anthropic provider sends it as the `model` field.
    pub model: String,
    /// Prompt pack identifier; the mock providers echo it back.
    pub prompt_pack: String,
    /// Display name of the system prompt; copied into
    /// [`CompletionResponse::system_prompt`].
    pub system_prompt_name: String,
    /// Full system prompt text; the Anthropic provider sends it as `system`.
    pub system_prompt_body: String,
    /// Plain-text prompt. The Anthropic provider falls back to sending it as a
    /// single user message when `messages` is empty (after thinking blocks are
    /// stripped); the mocks inspect it to pick a canned script.
    pub prompt: String,
    /// Structured conversation history. Providers may prefer this over `prompt`.
    pub messages: Vec<ProviderMessage>,
    /// Tools advertised to the model.
    pub tools: Vec<ToolSpec>,
    /// Number of skills discovered by the caller; only reported back in responses.
    pub skill_count: usize,
    /// Configured MCP servers keyed by name; only their count is reported back here.
    pub mcp_servers: BTreeMap<String, McpServerConfig>,
}
26
/// Result of a non-streaming [`Provider::complete`] call.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Copied from [`CompletionRequest::system_prompt_name`] (the name, not the body).
    pub system_prompt: String,
    /// Assistant text. The Anthropic provider concatenates every content block
    /// that has a `text` field.
    pub response: String,
    /// Number of tools that were offered in the request.
    pub tool_count: usize,
    /// Echo of [`CompletionRequest::skill_count`].
    pub skill_count: usize,
    /// Number of MCP servers present in the request.
    pub mcp_server_count: usize,
}
35
36// --- Conversation message model (provider-facing) ---
37
/// Speaker of a [`ProviderMessage`]; serialized as `"user"` / `"assistant"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderRole {
    User,
    Assistant,
}
44
/// One conversation turn: a role plus its ordered content blocks.
///
/// Serializes as `{"role": ..., "content": [...]}`; the Anthropic provider sends
/// this representation directly as the request's `messages` array.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProviderMessage {
    pub role: ProviderRole,
    pub content: Vec<ProviderContentBlock>,
}
50
/// One content block inside a [`ProviderMessage`].
///
/// Internally tagged: serialized as an object whose `"type"` field holds the
/// snake_case variant name (`"text"`, `"tool_use"`, `"tool_result"`,
/// `"thinking"`) next to the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool call made by the model. A missing `input` deserializes as JSON `null`.
    ToolUse {
        id: String,
        name: String,
        #[serde(default)]
        input: serde_json::Value,
    },
    /// Output of a tool call; `tool_use_id` presumably equals the `id` of the
    /// matching `ToolUse`. A missing `is_error` deserializes as `false`.
    ToolResult {
        tool_use_id: String,
        content: String,
        #[serde(default)]
        is_error: bool,
    },
    /// Model reasoning text. The Anthropic provider strips these blocks before
    /// sending the conversation.
    Thinking {
        thinking: String,
    },
}
73
74// --- Streaming event model ---
75
/// A tool invocation requested by the model during streaming.
#[derive(Debug, Clone, Serialize)]
pub struct ToolUseEvent {
    /// Call identifier; presumably matched later by `ToolResultEvent::tool_use_id`.
    pub id: String,
    /// Name of the tool to run (the mocks use `read_file`).
    pub name: String,
    /// Tool arguments as JSON text; the mocks build it with
    /// `serde_json::json!(..).to_string()`.
    pub input: String,
}
82
/// Outcome of a tool run, reported on the event stream.
// NOTE(review): no producer of this event is visible in this chunk — presumably
// the runtime emits it after executing a tool; confirm.
#[derive(Debug, Clone, Serialize)]
pub struct ToolResultEvent {
    /// Id of the tool call this result belongs to.
    pub tool_use_id: String,
    /// Tool output text.
    pub content: String,
    /// Whether the tool reported a failure.
    pub is_error: bool,
}
89
/// Token usage reported for a single request; accumulated by [`UsageAccount::record`].
#[derive(Debug, Clone, Serialize)]
pub struct UsageEvent {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub cache_read_tokens: u64,
    pub cache_write_tokens: u64,
}
97
/// One event of a streaming completion ([`EventStream`] item).
///
/// Serialized with serde's default externally tagged enum representation.
#[derive(Debug, Clone, Serialize)]
pub enum ApiEvent {
    /// Incremental assistant text.
    MessageDelta { text: String },
    /// Incremental model reasoning text.
    ThinkingDelta { text: String },
    /// The model asked for a tool to be run.
    ToolUse { tool_use: ToolUseEvent },
    /// Result of a tool run (no producer is visible in this chunk — confirm).
    ToolResult { tool_result: ToolResultEvent },
    /// Token usage for the request.
    Usage { usage: UsageEvent },
    /// End of the response; the mock providers emit it as their last event.
    Completed,
}
107
108// --- Error types ---
109
/// Failures a [`Provider`] can report.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderError {
    /// The request could not be sent or the response stream failed;
    /// `message` is the underlying error's text.
    Network { message: String },
    /// The server answered with a non-success HTTP status; `message` is the
    /// response body text.
    Api { status: u16, message: String },
    /// A response body could not be decoded.
    Parse { message: String },
    /// The per-request timeout elapsed. `elapsed_ms` is the configured limit,
    /// not a measured duration.
    Timeout { elapsed_ms: u64 },
    /// The retry loop gave up; `last_error` is the `Display` text of the final failure.
    RetryExhausted { attempts: u32, last_error: String },
    /// Anything else; displayed verbatim.
    Other { message: String },
}
119
120impl std::fmt::Display for ProviderError {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            ProviderError::Network { message } => write!(f, "Network error: {message}"),
124            ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125            ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126            ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127            ProviderError::RetryExhausted {
128                attempts,
129                last_error,
130            } => {
131                write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132            }
133            ProviderError::Other { message } => write!(f, "{message}"),
134        }
135    }
136}
137
138impl std::error::Error for ProviderError {}
139
140// --- Usage accounting ---
141
/// Running token-usage totals across requests.
#[derive(Debug, Clone, Default, Serialize)]
pub struct UsageAccount {
    pub total_input_tokens: u64,
    pub total_output_tokens: u64,
    pub total_cache_read_tokens: u64,
    pub total_cache_write_tokens: u64,
    /// Number of [`UsageEvent`]s passed to [`UsageAccount::record`].
    pub request_count: u64,
}
150
151impl UsageAccount {
152    pub fn record(&mut self, usage: &UsageEvent) {
153        self.total_input_tokens += usage.input_tokens;
154        self.total_output_tokens += usage.output_tokens;
155        self.total_cache_read_tokens += usage.cache_read_tokens;
156        self.total_cache_write_tokens += usage.cache_write_tokens;
157        self.request_count += 1;
158    }
159}
160
161// --- Retry / Timeout envelope ---
162
/// Retry policy for provider HTTP requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,
    /// Initial delay between attempts.
    pub base_delay: Duration,
    /// Upper bound on the computed (exponential) backoff delay.
    pub max_delay: Duration,
}
169
170impl Default for RetryConfig {
171    fn default() -> Self {
172        Self {
173            max_attempts: 3,
174            base_delay: Duration::from_millis(200),
175            max_delay: Duration::from_secs(5),
176        }
177    }
178}
179
/// Time limits for provider HTTP requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Limit passed to `tokio::time::timeout` around each HTTP request attempt.
    pub per_request: Duration,
}
184
185impl Default for TimeoutConfig {
186    fn default() -> Self {
187        Self {
188            per_request: Duration::from_secs(60),
189        }
190    }
191}
192
193// --- Provider trait (async, streaming) ---
194
/// Boxed, pinned, `Send` stream of streaming events or errors, returned by
/// [`Provider::stream`]. Being a `Box<dyn …>` it is `'static` and cannot borrow
/// the provider or the request.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
/// A model backend that answers a [`CompletionRequest`] either in one shot or
/// as a stream of [`ApiEvent`]s.
///
/// The `Send + Sync` bound lets providers (e.g. a [`BoxedProvider`]) be shared
/// across threads.
pub trait Provider: Send + Sync {
    /// Produces a complete response. The returned future may borrow `self` but
    /// not `request`, so implementations copy whatever request data they need.
    fn complete(
        &self,
        request: &CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;

    /// Starts a streaming completion; the returned [`EventStream`] owns all of
    /// its data.
    fn stream(&self, request: &CompletionRequest) -> EventStream;
}
205
206// --- Boxed provider handle ---
207
/// Owned, type-erased provider: any [`Provider`] behind a `Box`.
pub type BoxedProvider = Box<dyn Provider>;
209
210// --- Mock provider (streams with delays) ---
211
/// Offline [`Provider`] that returns scripted responses.
///
/// Prompts that mention `cargo.toml` together with "read" or "summarize"
/// (case-insensitive) follow a two-step script. Until the conversation contains
/// a tool result, `stream` emits a `read_file` tool call and `complete` returns
/// a short announcement. After that, both return a summary of the first tool
/// result. Any other prompt gets a status report echoing the request metadata.
/// Streamed events are spaced 18 ms apart.
#[derive(Debug, Default, Clone)]
pub struct MockProvider;
214
215impl Provider for MockProvider {
216    fn complete(
217        &self,
218        request: &CompletionRequest,
219    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220        let response = mock_complete_response(request);
221        Box::pin(async move { Ok(response) })
222    }
223
224    fn stream(&self, request: &CompletionRequest) -> EventStream {
225        let req = request.clone();
226        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228        tokio::spawn(async move {
229            let events = mock_stream_events(&req);
230            for (idx, event) in events.into_iter().enumerate() {
231                if idx > 0 {
232                    tokio::time::sleep(Duration::from_millis(18)).await;
233                }
234                if tx.send(Ok(event)).await.is_err() {
235                    return;
236                }
237            }
238        });
239
240        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241            rx.recv().await.map(|item| (item, rx))
242        }))
243    }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247    if wants_read_cargo_toml(&request.prompt) {
248        if let Some(tool_content) = first_tool_result_content(&request.messages) {
249            let response = mock_summarize_cargo_toml(&tool_content);
250            return CompletionResponse {
251                system_prompt: request.system_prompt_name.clone(),
252                response,
253                tool_count: request.tools.len(),
254                skill_count: request.skill_count,
255                mcp_server_count: request.mcp_servers.len(),
256            };
257        }
258
259        return CompletionResponse {
260            system_prompt: request.system_prompt_name.clone(),
261            response: "I'll read Cargo.toml first.".to_string(),
262            tool_count: request.tools.len(),
263            skill_count: request.skill_count,
264            mcp_server_count: request.mcp_servers.len(),
265        };
266    }
267
268    let response = format!(
269        "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n\nRequest queued for the execution loop.\n\nNext priorities:\n1. Parse instructions into an explicit task graph.\n2. Resolve tool approvals before execution.\n3. Stream structured updates into the terminal UI.",
270        request.model,
271        request.prompt_pack,
272        request.system_prompt_name,
273        request
274            .tools
275            .iter()
276            .map(|tool| tool.name)
277            .collect::<Vec<_>>()
278            .join(", "),
279        request.skill_count,
280        request.mcp_servers.len(),
281    );
282
283    CompletionResponse {
284        system_prompt: request.system_prompt_name.clone(),
285        response,
286        tool_count: request.tools.len(),
287        skill_count: request.skill_count,
288        mcp_server_count: request.mcp_servers.len(),
289    }
290}
291
292fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
293    if wants_read_cargo_toml(&request.prompt) {
294        if let Some(tool_content) = first_tool_result_content(&request.messages) {
295            return vec![
296                ApiEvent::ThinkingDelta {
297                    text: "I have the Cargo.toml contents; summarizing.".to_string(),
298                },
299                ApiEvent::MessageDelta {
300                    text: mock_summarize_cargo_toml(&tool_content),
301                },
302                ApiEvent::Usage {
303                    usage: UsageEvent {
304                        input_tokens: 220,
305                        output_tokens: 90,
306                        cache_read_tokens: 0,
307                        cache_write_tokens: 0,
308                    },
309                },
310                ApiEvent::Completed,
311            ];
312        }
313
314        return vec![
315            ApiEvent::ThinkingDelta {
316                text: "I should read Cargo.toml to answer this.".to_string(),
317            },
318            ApiEvent::ToolUse {
319                tool_use: ToolUseEvent {
320                    id: "tool_1".to_string(),
321                    name: "read_file".to_string(),
322                    input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
323                },
324            },
325            ApiEvent::Completed,
326        ];
327    }
328
329    let tool_names: Vec<&str> = request.tools.iter().map(|t| t.name.as_ref()).collect();
330
331    vec![
332        ApiEvent::ThinkingDelta {
333            text: "Let me start by analyzing the request.".to_string(),
334        },
335        ApiEvent::MessageDelta {
336            text: format!(
337                "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n",
338                request.model,
339                request.prompt_pack,
340                request.system_prompt_name,
341                tool_names.join(", "),
342                request.skill_count,
343                request.mcp_servers.len(),
344            ),
345        },
346        ApiEvent::MessageDelta {
347            text: "\nRequest queued for the execution loop.\n\nNext priorities:\n".to_string(),
348        },
349        ApiEvent::MessageDelta {
350            text: "1. Parse instructions into an explicit task graph.\n".to_string(),
351        },
352        ApiEvent::MessageDelta {
353            text: "2. Resolve tool approvals before execution.\n".to_string(),
354        },
355        ApiEvent::MessageDelta {
356            text: "3. Stream structured updates into the terminal UI.".to_string(),
357        },
358        ApiEvent::Usage {
359            usage: UsageEvent {
360                input_tokens: 120,
361                output_tokens: 85,
362                cache_read_tokens: 0,
363                cache_write_tokens: 0,
364            },
365        },
366        ApiEvent::Completed,
367    ]
368}
369
/// True when the prompt (case-insensitively) mentions `cargo.toml` and also
/// contains "read" or "summarize" anywhere as a substring.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    let mentions_manifest = lowered.contains("cargo.toml");
    let mentions_verb = ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(verb));
    mentions_manifest && mentions_verb
}
374
375fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
376    for m in messages {
377        for b in &m.content {
378            if let ProviderContentBlock::ToolResult { content, .. } = b {
379                return Some(content.clone());
380            }
381        }
382    }
383    None
384}
385
/// Canned one-paragraph "summary" of a Cargo.toml, keyed only on whether the
/// text contains `[workspace]` and `members`.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section).".to_string();
    }

    let members_line = if contents.contains("members") {
        "It declares workspace members; this repo is a multi-crate workspace.\n"
    } else {
        ""
    };
    [
        "Cargo.toml defines a Rust workspace.\n",
        members_line,
        "Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.",
    ]
    .concat()
}
398
399// --- MockToolProvider: deterministic provider for tests that triggers a tool call ---
400
/// A deterministic mock provider for tests that exercises one tool round-trip.
///
/// If no message in the request contains a `ToolResult` content block, the
/// provider asks to read `Cargo.toml`. `stream` emits a `read_file` ToolUse
/// (id `tool_1`, input `{"path":"Cargo.toml"}`), and `complete` returns a short
/// announcement. Once any message contains a `ToolResult` block (i.e. after the
/// runtime has appended the tool output), it returns a fixed summary sentence.
/// That sentence is canned and does not depend on the tool result's contents.
///
/// The provider is stateless. "First" versus "subsequent" call is decided only
/// by inspecting `request.messages` for `ToolResult` blocks; there is no `tool`
/// role, and calls are not counted. Streamed events are spaced 5 ms apart.
#[derive(Debug, Default, Clone)]
pub struct MockToolProvider;
411
412impl Provider for MockToolProvider {
413    fn complete(
414        &self,
415        request: &CompletionRequest,
416    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
417        let response = mock_tool_complete_response(request);
418        Box::pin(async move { Ok(response) })
419    }
420
421    fn stream(&self, request: &CompletionRequest) -> EventStream {
422        let req = request.clone();
423        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
424
425        tokio::spawn(async move {
426            let events = mock_tool_stream_events(&req);
427            for (idx, event) in events.into_iter().enumerate() {
428                if idx > 0 {
429                    tokio::time::sleep(Duration::from_millis(5)).await;
430                }
431                if tx.send(Ok(event)).await.is_err() {
432                    return;
433                }
434            }
435        });
436
437        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
438            rx.recv().await.map(|item| (item, rx))
439        }))
440    }
441}
442
443fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
444    if has_tool_result_message(request) {
445        let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
446        return CompletionResponse {
447            system_prompt: request.system_prompt_name.clone(),
448            response,
449            tool_count: request.tools.len(),
450            skill_count: request.skill_count,
451            mcp_server_count: request.mcp_servers.len(),
452        };
453    }
454
455    CompletionResponse {
456        system_prompt: request.system_prompt_name.clone(),
457        response: "I'll read the Cargo.toml file.".to_string(),
458        tool_count: request.tools.len(),
459        skill_count: request.skill_count,
460        mcp_server_count: request.mcp_servers.len(),
461    }
462}
463
464fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
465    if has_tool_result_message(request) {
466        return vec![
467            ApiEvent::ThinkingDelta {
468                text: "I have the file contents now.".to_string(),
469            },
470            ApiEvent::MessageDelta {
471                text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
472            },
473            ApiEvent::Usage {
474                usage: UsageEvent {
475                    input_tokens: 200,
476                    output_tokens: 60,
477                    cache_read_tokens: 0,
478                    cache_write_tokens: 0,
479                },
480            },
481            ApiEvent::Completed,
482        ];
483    }
484
485    vec![
486        ApiEvent::ThinkingDelta {
487            text: "I should read the Cargo.toml file.".to_string(),
488        },
489        ApiEvent::ToolUse {
490            tool_use: ToolUseEvent {
491                id: "tool_1".to_string(),
492                name: "read_file".to_string(),
493                input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
494            },
495        },
496        ApiEvent::Usage {
497            usage: UsageEvent {
498                input_tokens: 100,
499                output_tokens: 30,
500                cache_read_tokens: 0,
501                cache_write_tokens: 0,
502            },
503        },
504        ApiEvent::Completed,
505    ]
506}
507
508/// Returns true if the conversation already contains tool result blocks,
509/// indicating this is a re-query after tool execution.
510fn has_tool_result_message(request: &CompletionRequest) -> bool {
511    request.messages.iter().any(|m| {
512        m.content
513            .iter()
514            .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
515    })
516}
517
518// --- Optional Anthropic provider (behind feature flag + env var) ---
519
520#[cfg(feature = "anthropic")]
521pub mod anthropic_provider {
522    use super::*;
523    use reqwest::Client;
524
    /// [`Provider`] backed by the Anthropic Messages HTTP API.
    ///
    /// Only compiled with the `anthropic` feature.
    #[derive(Debug, Clone)]
    pub struct AnthropicProvider {
        /// reqwest HTTP client; cloned into each request future.
        client: Client,
        /// Sent as the `x-api-key` header.
        api_key: String,
        /// Full URL that requests are POSTed to.
        endpoint: String,
        /// Sent as the `anthropic-version` header.
        anthropic_version: String,
        /// Retry policy used by `complete`.
        retry: RetryConfig,
        /// Per-request time limit.
        timeout: TimeoutConfig,
    }
534
535    impl AnthropicProvider {
536        pub fn from_env() -> Option<Self> {
537            let api_key = std::env::var("ANTHROPIC_API_KEY").ok()?;
538            let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
539                .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
540                .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
541            let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
542                .unwrap_or_else(|_| "2023-06-01".to_string());
543            Some(Self {
544                client: Client::new(),
545                api_key,
546                endpoint,
547                anthropic_version,
548                retry: RetryConfig::default(),
549                timeout: TimeoutConfig::default(),
550            })
551        }
552
553        pub fn new(api_key: String, endpoint: String) -> Self {
554            Self {
555                client: Client::new(),
556                api_key,
557                endpoint,
558                anthropic_version: "2023-06-01".to_string(),
559                retry: RetryConfig::default(),
560                timeout: TimeoutConfig::default(),
561            }
562        }
563    }
564
565    fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
566        if tools.is_empty() {
567            return serde_json::Value::Null;
568        }
569        serde_json::Value::Array(
570            tools
571                .iter()
572                .map(|t| {
573                    serde_json::json!({
574                        "name": t.name,
575                        "description": t.description,
576                        "input_schema": t.input_schema,
577                    })
578                })
579                .collect(),
580        )
581    }
582
583    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
584        messages
585            .iter()
586            .map(|m| ProviderMessage {
587                role: m.role.clone(),
588                content: m
589                    .content
590                    .iter()
591                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
592                    .cloned()
593                    .collect(),
594            })
595            .filter(|m| !m.content.is_empty())
596            .collect()
597    }
598
599    impl Provider for AnthropicProvider {
600        fn complete(
601            &self,
602            request: &CompletionRequest,
603        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
604        {
605            let req = request.clone();
606            let api_key = self.api_key.clone();
607            let client = self.client.clone();
608            let endpoint = self.endpoint.clone();
609            let anthropic_version = self.anthropic_version.clone();
610            let retry = self.retry.clone();
611            let timeout = self.timeout.clone();
612            Box::pin(async move {
613                let tools = build_anthropic_tools(&req.tools);
614                let messages = sanitize_messages_for_anthropic(&req.messages);
615                let messages = if messages.is_empty() {
616                    serde_json::json!([{"role": "user", "content": req.prompt}])
617                } else {
618                    serde_json::to_value(messages).unwrap_or_else(
619                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
620                    )
621                };
622                let mut body = serde_json::json!({
623                    "model": req.model,
624                    "max_tokens": 4096,
625                    "system": req.system_prompt_body,
626                    "messages": messages,
627                });
628                if !tools.is_null() {
629                    body["tools"] = tools;
630                }
631
632                let mut last_err: Option<ProviderError> = None;
633
634                for attempt in 1..=retry.max_attempts {
635                    let send_fut = client
636                        .post(&endpoint)
637                        .header("x-api-key", &api_key)
638                        .header("anthropic-version", &anthropic_version)
639                        .header("content-type", "application/json")
640                        .json(&body)
641                        .send();
642
643                    let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
644                        Ok(Ok(r)) => r,
645                        Ok(Err(e)) => {
646                            last_err = Some(ProviderError::Network {
647                                message: e.to_string(),
648                            });
649                            if attempt < retry.max_attempts {
650                                let backoff_ms = (retry.base_delay.as_millis() as u64)
651                                    .saturating_mul(1u64 << (attempt - 1));
652                                tokio::time::sleep(Duration::from_millis(
653                                    backoff_ms.min(retry.max_delay.as_millis() as u64),
654                                ))
655                                .await;
656                                continue;
657                            }
658                            break;
659                        }
660                        Err(_) => {
661                            last_err = Some(ProviderError::Timeout {
662                                elapsed_ms: timeout.per_request.as_millis() as u64,
663                            });
664                            if attempt < retry.max_attempts {
665                                tokio::time::sleep(retry.base_delay).await;
666                                continue;
667                            }
668                            break;
669                        }
670                    };
671
672                    let status = resp.status().as_u16();
673                    if !resp.status().is_success() {
674                        let text = resp.text().await.unwrap_or_default();
675                        let err = ProviderError::Api {
676                            status,
677                            message: text,
678                        };
679                        last_err = Some(err);
680
681                        // Retry 5xx; fail fast otherwise.
682                        let retryable = (500..=599).contains(&status);
683                        if retryable && attempt < retry.max_attempts {
684                            tokio::time::sleep(retry.base_delay).await;
685                            continue;
686                        }
687                        break;
688                    }
689
690                    let json: serde_json::Value =
691                        resp.json().await.map_err(|e| ProviderError::Parse {
692                            message: e.to_string(),
693                        })?;
694
695                    let response = json["content"]
696                        .as_array()
697                        .map(|arr| {
698                            arr.iter()
699                                .filter_map(|block| block["text"].as_str())
700                                .collect::<Vec<_>>()
701                                .join("")
702                        })
703                        .unwrap_or_default();
704
705                    return Ok(CompletionResponse {
706                        system_prompt: req.system_prompt_name.clone(),
707                        response,
708                        tool_count: req.tools.len(),
709                        skill_count: req.skill_count,
710                        mcp_server_count: req.mcp_servers.len(),
711                    });
712                }
713
714                Err(match last_err {
715                    Some(e) => ProviderError::RetryExhausted {
716                        attempts: retry.max_attempts,
717                        last_error: e.to_string(),
718                    },
719                    None => ProviderError::Other {
720                        message: "request failed".to_string(),
721                    },
722                })
723            })
724        }
725
726        fn stream(&self, request: &CompletionRequest) -> EventStream {
727            let req = request.clone();
728            let api_key = self.api_key.clone();
729            let client = self.client.clone();
730            let endpoint = self.endpoint.clone();
731            let anthropic_version = self.anthropic_version.clone();
732            let timeout = self.timeout.clone();
733
734            let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
735
736            tokio::spawn(async move {
737                let tools = build_anthropic_tools(&req.tools);
738                let messages = sanitize_messages_for_anthropic(&req.messages);
739                let messages = if messages.is_empty() {
740                    serde_json::json!([{"role": "user", "content": req.prompt}])
741                } else {
742                    serde_json::to_value(messages).unwrap_or_else(
743                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
744                    )
745                };
746                let mut body = serde_json::json!({
747                    "model": req.model,
748                    "max_tokens": 4096,
749                    "system": req.system_prompt_body,
750                    "messages": messages,
751                    "stream": true,
752                });
753                if !tools.is_null() {
754                    body["tools"] = tools;
755                }
756
757                let send_fut = client
758                    .post(&endpoint)
759                    .header("x-api-key", &api_key)
760                    .header("anthropic-version", &anthropic_version)
761                    .header("content-type", "application/json")
762                    .json(&body)
763                    .send();
764
765                let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
766                    Ok(Ok(r)) => r,
767                    Ok(Err(e)) => {
768                        let _ = tx
769                            .send(Err(ProviderError::Network {
770                                message: e.to_string(),
771                            }))
772                            .await;
773                        return;
774                    }
775                    Err(_) => {
776                        let _ = tx
777                            .send(Err(ProviderError::Timeout {
778                                elapsed_ms: timeout.per_request.as_millis() as u64,
779                            }))
780                            .await;
781                        return;
782                    }
783                };
784
785                if !resp.status().is_success() {
786                    let status = resp.status().as_u16();
787                    let text = resp.text().await.unwrap_or_default();
788                    let _ = tx
789                        .send(Err(ProviderError::Api {
790                            status,
791                            message: text,
792                        }))
793                        .await;
794                    return;
795                }
796
797                let mut buf = String::new();
798                let mut bytes = resp.bytes_stream();
799                use futures_util::StreamExt;
800
801                let mut current_tool_use: Option<(String, String, String)> = None;
802
803                while let Some(chunk) = bytes.next().await {
804                    let chunk = match chunk {
805                        Ok(c) => c,
806                        Err(e) => {
807                            let _ = tx
808                                .send(Err(ProviderError::Network {
809                                    message: e.to_string(),
810                                }))
811                                .await;
812                            return;
813                        }
814                    };
815
816                    buf.push_str(&String::from_utf8_lossy(&chunk));
817
818                    // SSE frames are separated by a blank line.
819                    while let Some(idx) = buf.find("\n\n") {
820                        let frame: String = buf.drain(..(idx + 2)).collect();
821                        let mut data_lines = Vec::new();
822                        for line in frame.lines() {
823                            let line = line.trim();
824                            if let Some(rest) = line.strip_prefix("data:") {
825                                let payload = rest.trim();
826                                if !payload.is_empty() {
827                                    data_lines.push(payload.to_string());
828                                }
829                            }
830                        }
831
832                        if data_lines.is_empty() {
833                            continue;
834                        }
835
836                        let data = data_lines.join("\n");
837                        if data == "[DONE]" {
838                            continue;
839                        }
840
841                        let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
842                            continue;
843                        };
844                        let typ = event["type"].as_str().unwrap_or("");
845                        match typ {
846                            "content_block_start" => {
847                                let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
848                                if cb_type == "text" {
849                                    if let Some(t) = event["content_block"]["text"].as_str() {
850                                        if !t.is_empty() {
851                                            let _ = tx
852                                                .send(Ok(ApiEvent::MessageDelta {
853                                                    text: t.to_string(),
854                                                }))
855                                                .await;
856                                        }
857                                    }
858                                } else if cb_type == "tool_use" {
859                                    let id = event["content_block"]["id"]
860                                        .as_str()
861                                        .unwrap_or("")
862                                        .to_string();
863                                    let name = event["content_block"]["name"]
864                                        .as_str()
865                                        .unwrap_or("")
866                                        .to_string();
867                                    let input = event["content_block"]["input"].clone();
868                                    let input_str = if input.is_null() {
869                                        String::new()
870                                    } else {
871                                        serde_json::to_string(&input).unwrap_or_default()
872                                    };
873                                    current_tool_use = Some((id, name, input_str));
874                                } else if cb_type == "thinking" {
875                                    if let Some(t) = event["content_block"]["thinking"].as_str() {
876                                        if !t.is_empty() {
877                                            let _ = tx
878                                                .send(Ok(ApiEvent::ThinkingDelta {
879                                                    text: t.to_string(),
880                                                }))
881                                                .await;
882                                        }
883                                    }
884                                }
885                            }
886                            "content_block_delta" => {
887                                let delta_type = event["delta"]["type"].as_str().unwrap_or("");
888                                match delta_type {
889                                    "text_delta" => {
890                                        if let Some(t) = event["delta"]["text"].as_str() {
891                                            let _ = tx
892                                                .send(Ok(ApiEvent::MessageDelta {
893                                                    text: t.to_string(),
894                                                }))
895                                                .await;
896                                        }
897                                    }
898                                    "thinking_delta" => {
899                                        if let Some(t) = event["delta"]["thinking"].as_str() {
900                                            let _ = tx
901                                                .send(Ok(ApiEvent::ThinkingDelta {
902                                                    text: t.to_string(),
903                                                }))
904                                                .await;
905                                        }
906                                    }
907                                    "input_json_delta" => {
908                                        if let Some(partial) =
909                                            event["delta"]["partial_json"].as_str()
910                                        {
911                                            if let Some((_id, _name, input_buf)) =
912                                                current_tool_use.as_mut()
913                                            {
914                                                input_buf.push_str(partial);
915                                            }
916                                        }
917                                    }
918                                    _ => {}
919                                }
920                            }
921                            "content_block_stop" => {
922                                if let Some((id, name, input)) = current_tool_use.take() {
923                                    let input = if input.trim().is_empty() {
924                                        "{}".to_string()
925                                    } else {
926                                        input
927                                    };
928                                    let _ = tx
929                                        .send(Ok(ApiEvent::ToolUse {
930                                            tool_use: ToolUseEvent { id, name, input },
931                                        }))
932                                        .await;
933                                }
934                            }
935                            "message_delta" => {
936                                if let Some(usage) = event.get("usage") {
937                                    let _ = tx
938                                        .send(Ok(ApiEvent::Usage {
939                                            usage: UsageEvent {
940                                                input_tokens: usage["input_tokens"]
941                                                    .as_u64()
942                                                    .unwrap_or(0),
943                                                output_tokens: usage["output_tokens"]
944                                                    .as_u64()
945                                                    .unwrap_or(0),
946                                                cache_read_tokens: usage["cache_read_input_tokens"]
947                                                    .as_u64()
948                                                    .unwrap_or(0),
949                                                cache_write_tokens:
950                                                    usage["cache_creation_input_tokens"]
951                                                        .as_u64()
952                                                        .unwrap_or(0),
953                                            },
954                                        }))
955                                        .await;
956                                }
957                            }
958                            "message_stop" => {
959                                let _ = tx.send(Ok(ApiEvent::Completed)).await;
960                                return;
961                            }
962                            _ => {}
963                        }
964                    }
965                }
966
967                let _ = tx.send(Ok(ApiEvent::Completed)).await;
968            });
969
970            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
971                rx.recv().await.map(|item| (item, rx))
972            }))
973        }
974    }
975}
976
977// --- Provider factory (picks provider from env) ---
978
979pub fn create_provider() -> BoxedProvider {
980    let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
981
982    match provider_name.as_str() {
983        #[cfg(feature = "anthropic")]
984        "anthropic" => {
985            if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
986                tracing::info!("Using Anthropic provider");
987                return Box::new(p);
988            }
989            tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
990        }
991        #[cfg(not(feature = "anthropic"))]
992        "anthropic" => {
993            tracing::warn!(
994                "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
995            );
996        }
997        _ => {}
998    }
999
1000    tracing::info!("Using mock provider");
1001    Box::new(MockProvider)
1002}
1003
1004// --- Helper: collect stream into response ---
1005
1006pub async fn collect_stream_to_response(
1007    stream: EventStream,
1008    request: &CompletionRequest,
1009) -> Result<CompletionResponse, ProviderError> {
1010    use futures_util::StreamExt;
1011    let mut text = String::new();
1012    let mut thinking = String::new();
1013
1014    let mut s = stream;
1015    while let Some(event) = s.next().await {
1016        match event? {
1017            ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1018            ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1019            ApiEvent::Usage { usage: _ } => {}
1020            ApiEvent::Completed => break,
1021            ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1022        }
1023    }
1024
1025    Ok(CompletionResponse {
1026        system_prompt: request.system_prompt_name.clone(),
1027        response: text,
1028        tool_count: request.tools.len(),
1029        skill_count: request.skill_count,
1030        mcp_server_count: request.mcp_servers.len(),
1031    })
1032}
1033
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// The minimal request every mock-provider test sends: no history, tools, skills or MCP.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Runs `provider.stream(request)` to exhaustion, keeping only the successful events.
    async fn ok_events(provider: &MockProvider, request: &CompletionRequest) -> Vec<ApiEvent> {
        let collected: Vec<ApiEvent> = provider
            .stream(request)
            .filter_map(|item| async move { item.ok() })
            .collect()
            .await;
        collected
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let provider = MockProvider;
        let request = sample_request();

        let events = ok_events(&provider, &request).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let provider = MockProvider;
        let request = sample_request();

        let events = ok_events(&provider, &request).await;

        // The mock emits thinking first, then the first chunk of visible text.
        assert!(matches!(events.first(), Some(ApiEvent::ThinkingDelta { .. })));
        assert!(matches!(events.get(1), Some(ApiEvent::MessageDelta { .. })));

        let index_of = |wanted: fn(&ApiEvent) -> bool| events.iter().position(wanted);
        let usage_at = index_of(|ev| matches!(ev, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_at = index_of(|ev| matches!(ev, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_at < completed_at);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let provider = MockProvider;
        let request = sample_request();

        let direct = provider.complete(&request).await.unwrap();
        let events = ok_events(&provider, &request).await;

        let streamed: String = events
            .iter()
            .filter_map(|ev| match ev {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let reports = [
            UsageEvent {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_tokens: 10,
                cache_write_tokens: 20,
            },
            UsageEvent {
                input_tokens: 200,
                output_tokens: 75,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        ];

        let mut account = UsageAccount::default();
        for report in reports {
            account.record(&report);
        }

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let timeout = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(timeout.to_string().contains("5000"));

        let exhausted = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(exhausted.to_string().contains("3"));
    }
}