Skip to main content

clawedcode_api/
lib.rs

1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11// --- Request / Response types ---
12
13#[derive(Debug, Clone)]
14pub struct CompletionRequest {
15    pub model: String,
16    pub prompt_pack: String,
17    pub system_prompt_name: String,
18    pub system_prompt_body: String,
19    pub prompt: String,
20    /// Structured conversation history. Providers may prefer this over `prompt`.
21    pub messages: Vec<ProviderMessage>,
22    pub tools: Vec<ToolSpec>,
23    pub skill_count: usize,
24    pub mcp_servers: BTreeMap<String, McpServerConfig>,
25}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct CompletionResponse {
29    pub system_prompt: String,
30    pub response: String,
31    pub tool_count: usize,
32    pub skill_count: usize,
33    pub mcp_server_count: usize,
34}
35
36// --- Conversation message model (provider-facing) ---
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum ProviderRole {
41    User,
42    Assistant,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46pub struct ProviderMessage {
47    pub role: ProviderRole,
48    pub content: Vec<ProviderContentBlock>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ProviderContentBlock {
54    Text {
55        text: String,
56    },
57    ToolUse {
58        id: String,
59        name: String,
60        #[serde(default)]
61        input: serde_json::Value,
62    },
63    ToolResult {
64        tool_use_id: String,
65        content: String,
66        #[serde(default)]
67        is_error: bool,
68    },
69    Thinking {
70        thinking: String,
71    },
72}
73
74// --- Streaming event model ---
75
76#[derive(Debug, Clone, Serialize)]
77pub struct ToolUseEvent {
78    pub id: String,
79    pub name: String,
80    pub input: String,
81}
82
83#[derive(Debug, Clone, Serialize)]
84pub struct ToolResultEvent {
85    pub tool_use_id: String,
86    pub content: String,
87    pub is_error: bool,
88}
89
90#[derive(Debug, Clone, Serialize)]
91pub struct UsageEvent {
92    pub input_tokens: u64,
93    pub output_tokens: u64,
94    pub cache_read_tokens: u64,
95    pub cache_write_tokens: u64,
96}
97
98#[derive(Debug, Clone, Serialize)]
99pub enum ApiEvent {
100    MessageDelta { text: String },
101    ThinkingDelta { text: String },
102    ToolUse { tool_use: ToolUseEvent },
103    ToolResult { tool_result: ToolResultEvent },
104    Usage { usage: UsageEvent },
105    Completed,
106}
107
108// --- Error types ---
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub enum ProviderError {
112    Network { message: String },
113    Api { status: u16, message: String },
114    Parse { message: String },
115    Timeout { elapsed_ms: u64 },
116    RetryExhausted { attempts: u32, last_error: String },
117    Other { message: String },
118}
119
120impl std::fmt::Display for ProviderError {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            ProviderError::Network { message } => write!(f, "Network error: {message}"),
124            ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125            ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126            ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127            ProviderError::RetryExhausted {
128                attempts,
129                last_error,
130            } => {
131                write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132            }
133            ProviderError::Other { message } => write!(f, "{message}"),
134        }
135    }
136}
137
138impl std::error::Error for ProviderError {}
139
140// --- Usage accounting ---
141
142#[derive(Debug, Clone, Default, Serialize)]
143pub struct UsageAccount {
144    pub total_input_tokens: u64,
145    pub total_output_tokens: u64,
146    pub total_cache_read_tokens: u64,
147    pub total_cache_write_tokens: u64,
148    pub request_count: u64,
149}
150
151impl UsageAccount {
152    pub fn record(&mut self, usage: &UsageEvent) {
153        self.total_input_tokens += usage.input_tokens;
154        self.total_output_tokens += usage.output_tokens;
155        self.total_cache_read_tokens += usage.cache_read_tokens;
156        self.total_cache_write_tokens += usage.cache_write_tokens;
157        self.request_count += 1;
158    }
159}
160
161// --- Retry / Timeout envelope ---
162
/// Retry policy for provider requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of attempts, counting the first one.
    pub max_attempts: u32,
    /// Delay used as the starting point for back-off between attempts.
    pub base_delay: Duration,
    /// Upper bound for any single back-off delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, starting at 200 ms and never waiting longer than 5 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        Self {
            max_attempts: 3,
            base_delay,
            max_delay,
        }
    }
}
179
/// Time limits applied to provider requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time to wait for a single request to be answered.
    pub per_request: Duration,
}

impl Default for TimeoutConfig {
    /// Allows each request up to 60 seconds.
    fn default() -> Self {
        let per_request = Duration::from_secs(60);
        Self { per_request }
    }
}
192
193// --- Provider trait (async, streaming) ---
194
195pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
197pub trait Provider: Send + Sync {
198    fn complete(
199        &self,
200        request: &CompletionRequest,
201    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;
202
203    fn stream(&self, request: &CompletionRequest) -> EventStream;
204}
205
206// --- Boxed provider handle ---
207
208pub type BoxedProvider = Box<dyn Provider>;
209
210// --- Mock provider (streams with delays) ---
211
/// Offline provider that answers with canned text and streams its events
/// with short pauses between them.
#[derive(Debug, Default, Clone)]
pub struct MockProvider;
214
215impl Provider for MockProvider {
216    fn complete(
217        &self,
218        request: &CompletionRequest,
219    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220        let response = mock_complete_response(request);
221        Box::pin(async move { Ok(response) })
222    }
223
224    fn stream(&self, request: &CompletionRequest) -> EventStream {
225        let req = request.clone();
226        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228        tokio::spawn(async move {
229            let events = mock_stream_events(&req);
230            for (idx, event) in events.into_iter().enumerate() {
231                if idx > 0 {
232                    tokio::time::sleep(Duration::from_millis(18)).await;
233                }
234                if tx.send(Ok(event)).await.is_err() {
235                    return;
236                }
237            }
238        });
239
240        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241            rx.recv().await.map(|item| (item, rx))
242        }))
243    }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247    if wants_read_cargo_toml(&request.prompt) {
248        if let Some(tool_content) = first_tool_result_content(&request.messages) {
249            let response = mock_summarize_cargo_toml(&tool_content);
250            return CompletionResponse {
251                system_prompt: request.system_prompt_name.clone(),
252                response,
253                tool_count: request.tools.len(),
254                skill_count: request.skill_count,
255                mcp_server_count: request.mcp_servers.len(),
256            };
257        }
258
259        return CompletionResponse {
260            system_prompt: request.system_prompt_name.clone(),
261            response: "I'll read Cargo.toml first.".to_string(),
262            tool_count: request.tools.len(),
263            skill_count: request.skill_count,
264            mcp_server_count: request.mcp_servers.len(),
265        };
266    }
267
268    let response = format!(
269        "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n\nRequest queued for the execution loop.\n\nNext priorities:\n1. Parse instructions into an explicit task graph.\n2. Resolve tool approvals before execution.\n3. Stream structured updates into the terminal UI.",
270        request.model,
271        request.prompt_pack,
272        request.system_prompt_name,
273        request
274            .tools
275            .iter()
276            .map(|tool| tool.name.as_str())
277            .collect::<Vec<_>>()
278            .join(", "),
279        request.skill_count,
280        request.mcp_servers.len(),
281    );
282
283    CompletionResponse {
284        system_prompt: request.system_prompt_name.clone(),
285        response,
286        tool_count: request.tools.len(),
287        skill_count: request.skill_count,
288        mcp_server_count: request.mcp_servers.len(),
289    }
290}
291
292fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
293    if wants_read_cargo_toml(&request.prompt) {
294        if let Some(tool_content) = first_tool_result_content(&request.messages) {
295            return vec![
296                ApiEvent::ThinkingDelta {
297                    text: "I have the Cargo.toml contents; summarizing.".to_string(),
298                },
299                ApiEvent::MessageDelta {
300                    text: mock_summarize_cargo_toml(&tool_content),
301                },
302                ApiEvent::Usage {
303                    usage: UsageEvent {
304                        input_tokens: 220,
305                        output_tokens: 90,
306                        cache_read_tokens: 0,
307                        cache_write_tokens: 0,
308                    },
309                },
310                ApiEvent::Completed,
311            ];
312        }
313
314        return vec![
315            ApiEvent::ThinkingDelta {
316                text: "I should read Cargo.toml to answer this.".to_string(),
317            },
318            ApiEvent::ToolUse {
319                tool_use: ToolUseEvent {
320                    id: "tool_1".to_string(),
321                    name: "read_file".to_string(),
322                    input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
323                },
324            },
325            ApiEvent::Completed,
326        ];
327    }
328
329    let tool_names: Vec<&str> = request.tools.iter().map(|t| t.name.as_ref()).collect();
330
331    vec![
332        ApiEvent::ThinkingDelta {
333            text: "Let me start by analyzing the request.".to_string(),
334        },
335        ApiEvent::MessageDelta {
336            text: format!(
337                "Model: {}\nPrompt pack: {}\nSystem prompt: {}\nTools: {}\nSkills discovered: {}\nMCP servers discovered: {}\n",
338                request.model,
339                request.prompt_pack,
340                request.system_prompt_name,
341                tool_names.join(", "),
342                request.skill_count,
343                request.mcp_servers.len(),
344            ),
345        },
346        ApiEvent::MessageDelta {
347            text: "\nRequest queued for the execution loop.\n\nNext priorities:\n".to_string(),
348        },
349        ApiEvent::MessageDelta {
350            text: "1. Parse instructions into an explicit task graph.\n".to_string(),
351        },
352        ApiEvent::MessageDelta {
353            text: "2. Resolve tool approvals before execution.\n".to_string(),
354        },
355        ApiEvent::MessageDelta {
356            text: "3. Stream structured updates into the terminal UI.".to_string(),
357        },
358        ApiEvent::Usage {
359            usage: UsageEvent {
360                input_tokens: 120,
361                output_tokens: 85,
362                cache_read_tokens: 0,
363                cache_write_tokens: 0,
364            },
365        },
366        ApiEvent::Completed,
367    ]
368}
369
/// Returns `true` when `prompt` mentions `cargo.toml` together with "read" or
/// "summarize"; matching is ASCII case-insensitive.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    if !lowered.contains("cargo.toml") {
        return false;
    }
    ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb))
}
374
375fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
376    for m in messages {
377        for b in &m.content {
378            if let ProviderContentBlock::ToolResult { content, .. } = b {
379                return Some(content.clone());
380            }
381        }
382    }
383    None
384}
385
/// Produces the mock's canned summary for the given manifest text.
///
/// Only two substrings are inspected: `[workspace]` selects the workspace
/// summary, and `members` adds the sentence about workspace members.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let mut summary = String::new();
    summary.push_str("Cargo.toml defines a Rust workspace.\n");
    if contents.contains("members") {
        summary.push_str("It declares workspace members; this repo is a multi-crate workspace.\n");
    }
    summary.push_str("Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.");
    summary
}
398
399// --- MockToolProvider: deterministic provider for tests that triggers a tool call ---
400
/// A deterministic mock provider for tests that exercises the tool-call loop.
///
/// While the request history holds no tool result, streaming emits a
/// `read_file` `ToolUse` for `Cargo.toml`. Once the runtime has appended a
/// tool result, it returns a final text response about the file instead.
///
/// A "subsequent call" is detected by checking whether any message in
/// `CompletionRequest::messages` contains a `ProviderContentBlock::ToolResult`
/// block. There is no dedicated `tool` role; `ProviderRole` only has `User`
/// and `Assistant`.
#[derive(Debug, Default, Clone)]
pub struct MockToolProvider;
411
412impl Provider for MockToolProvider {
413    fn complete(
414        &self,
415        request: &CompletionRequest,
416    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
417        let response = mock_tool_complete_response(request);
418        Box::pin(async move { Ok(response) })
419    }
420
421    fn stream(&self, request: &CompletionRequest) -> EventStream {
422        let req = request.clone();
423        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
424
425        tokio::spawn(async move {
426            let events = mock_tool_stream_events(&req);
427            for (idx, event) in events.into_iter().enumerate() {
428                if idx > 0 {
429                    tokio::time::sleep(Duration::from_millis(5)).await;
430                }
431                if tx.send(Ok(event)).await.is_err() {
432                    return;
433                }
434            }
435        });
436
437        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
438            rx.recv().await.map(|item| (item, rx))
439        }))
440    }
441}
442
443fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
444    if has_tool_result_message(request) {
445        let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
446        return CompletionResponse {
447            system_prompt: request.system_prompt_name.clone(),
448            response,
449            tool_count: request.tools.len(),
450            skill_count: request.skill_count,
451            mcp_server_count: request.mcp_servers.len(),
452        };
453    }
454
455    CompletionResponse {
456        system_prompt: request.system_prompt_name.clone(),
457        response: "I'll read the Cargo.toml file.".to_string(),
458        tool_count: request.tools.len(),
459        skill_count: request.skill_count,
460        mcp_server_count: request.mcp_servers.len(),
461    }
462}
463
464fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
465    if has_tool_result_message(request) {
466        return vec![
467            ApiEvent::ThinkingDelta {
468                text: "I have the file contents now.".to_string(),
469            },
470            ApiEvent::MessageDelta {
471                text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
472            },
473            ApiEvent::Usage {
474                usage: UsageEvent {
475                    input_tokens: 200,
476                    output_tokens: 60,
477                    cache_read_tokens: 0,
478                    cache_write_tokens: 0,
479                },
480            },
481            ApiEvent::Completed,
482        ];
483    }
484
485    vec![
486        ApiEvent::ThinkingDelta {
487            text: "I should read the Cargo.toml file.".to_string(),
488        },
489        ApiEvent::ToolUse {
490            tool_use: ToolUseEvent {
491                id: "tool_1".to_string(),
492                name: "read_file".to_string(),
493                input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
494            },
495        },
496        ApiEvent::Usage {
497            usage: UsageEvent {
498                input_tokens: 100,
499                output_tokens: 30,
500                cache_read_tokens: 0,
501                cache_write_tokens: 0,
502            },
503        },
504        ApiEvent::Completed,
505    ]
506}
507
508/// Returns true if the conversation already contains tool result blocks,
509/// indicating this is a re-query after tool execution.
510fn has_tool_result_message(request: &CompletionRequest) -> bool {
511    request.messages.iter().any(|m| {
512        m.content
513            .iter()
514            .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
515    })
516}
517
518// --- Optional Anthropic provider (behind feature flag + env var) ---
519
520#[cfg(feature = "anthropic")]
521pub mod anthropic_provider {
522    use super::*;
523    use reqwest::Client;
524
525    #[derive(Debug, Clone)]
526    pub struct AnthropicProvider {
527        client: Client,
528        api_key: String,
529        endpoint: String,
530        anthropic_version: String,
531        retry: RetryConfig,
532        timeout: TimeoutConfig,
533    }
534
535    pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
536        if endpoint.contains("/v1/messages") {
537            endpoint.to_string()
538        } else {
539            let endpoint = endpoint.trim_end_matches('/');
540            format!("{}/v1/messages", endpoint)
541        }
542    }
543
544    impl AnthropicProvider {
545        pub fn from_env() -> Option<Self> {
546            let api_key = std::env::var("ANTHROPIC_API_KEY")
547                .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
548                .ok()?;
549            let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
550                .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
551                .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
552            let endpoint = normalize_anthropic_endpoint(&endpoint);
553            let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
554                .unwrap_or_else(|_| "2023-06-01".to_string());
555            Some(Self {
556                client: Client::new(),
557                api_key,
558                endpoint,
559                anthropic_version,
560                retry: RetryConfig::default(),
561                timeout: TimeoutConfig::default(),
562            })
563        }
564
565        pub fn new(api_key: String, endpoint: String) -> Self {
566            Self {
567                client: Client::new(),
568                api_key,
569                endpoint,
570                anthropic_version: "2023-06-01".to_string(),
571                retry: RetryConfig::default(),
572                timeout: TimeoutConfig::default(),
573            }
574        }
575    }
576
577    fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
578        if tools.is_empty() {
579            return serde_json::Value::Null;
580        }
581        serde_json::Value::Array(
582            tools
583                .iter()
584                .map(|t| {
585                    serde_json::json!({
586                        "name": t.name,
587                        "description": t.description,
588                        "input_schema": t.input_schema,
589                    })
590                })
591                .collect(),
592        )
593    }
594
595    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
596        messages
597            .iter()
598            .map(|m| ProviderMessage {
599                role: m.role.clone(),
600                content: m
601                    .content
602                    .iter()
603                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
604                    .cloned()
605                    .collect(),
606            })
607            .filter(|m| !m.content.is_empty())
608            .collect()
609    }
610
611    impl Provider for AnthropicProvider {
612        fn complete(
613            &self,
614            request: &CompletionRequest,
615        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
616        {
617            let req = request.clone();
618            let api_key = self.api_key.clone();
619            let client = self.client.clone();
620            let endpoint = self.endpoint.clone();
621            let anthropic_version = self.anthropic_version.clone();
622            let retry = self.retry.clone();
623            let timeout = self.timeout.clone();
624            Box::pin(async move {
625                let tools = build_anthropic_tools(&req.tools);
626                let messages = sanitize_messages_for_anthropic(&req.messages);
627                let messages = if messages.is_empty() {
628                    serde_json::json!([{"role": "user", "content": req.prompt}])
629                } else {
630                    serde_json::to_value(messages).unwrap_or_else(
631                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
632                    )
633                };
634                let mut body = serde_json::json!({
635                    "model": req.model,
636                    "max_tokens": 4096,
637                    "system": req.system_prompt_body,
638                    "messages": messages,
639                });
640                if !tools.is_null() {
641                    body["tools"] = tools;
642                }
643
644                let mut last_err: Option<ProviderError> = None;
645
646                for attempt in 1..=retry.max_attempts {
647                    let send_fut = client
648                        .post(&endpoint)
649                        .header("x-api-key", &api_key)
650                        .header("anthropic-version", &anthropic_version)
651                        .header("content-type", "application/json")
652                        .json(&body)
653                        .send();
654
655                    let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
656                        Ok(Ok(r)) => r,
657                        Ok(Err(e)) => {
658                            last_err = Some(ProviderError::Network {
659                                message: e.to_string(),
660                            });
661                            if attempt < retry.max_attempts {
662                                let backoff_ms = (retry.base_delay.as_millis() as u64)
663                                    .saturating_mul(1u64 << (attempt - 1));
664                                tokio::time::sleep(Duration::from_millis(
665                                    backoff_ms.min(retry.max_delay.as_millis() as u64),
666                                ))
667                                .await;
668                                continue;
669                            }
670                            break;
671                        }
672                        Err(_) => {
673                            last_err = Some(ProviderError::Timeout {
674                                elapsed_ms: timeout.per_request.as_millis() as u64,
675                            });
676                            if attempt < retry.max_attempts {
677                                tokio::time::sleep(retry.base_delay).await;
678                                continue;
679                            }
680                            break;
681                        }
682                    };
683
684                    let status = resp.status().as_u16();
685                    if !resp.status().is_success() {
686                        let text = resp.text().await.unwrap_or_default();
687                        let err = ProviderError::Api {
688                            status,
689                            message: text,
690                        };
691                        last_err = Some(err);
692
693                        // Retry 5xx; fail fast otherwise.
694                        let retryable = (500..=599).contains(&status);
695                        if retryable && attempt < retry.max_attempts {
696                            tokio::time::sleep(retry.base_delay).await;
697                            continue;
698                        }
699                        break;
700                    }
701
702                    let json: serde_json::Value =
703                        resp.json().await.map_err(|e| ProviderError::Parse {
704                            message: e.to_string(),
705                        })?;
706
707                    let response = json["content"]
708                        .as_array()
709                        .map(|arr| {
710                            arr.iter()
711                                .filter_map(|block| block["text"].as_str())
712                                .collect::<Vec<_>>()
713                                .join("")
714                        })
715                        .unwrap_or_default();
716
717                    return Ok(CompletionResponse {
718                        system_prompt: req.system_prompt_name.clone(),
719                        response,
720                        tool_count: req.tools.len(),
721                        skill_count: req.skill_count,
722                        mcp_server_count: req.mcp_servers.len(),
723                    });
724                }
725
726                Err(match last_err {
727                    Some(e) => ProviderError::RetryExhausted {
728                        attempts: retry.max_attempts,
729                        last_error: e.to_string(),
730                    },
731                    None => ProviderError::Other {
732                        message: "request failed".to_string(),
733                    },
734                })
735            })
736        }
737
738        fn stream(&self, request: &CompletionRequest) -> EventStream {
739            let req = request.clone();
740            let api_key = self.api_key.clone();
741            let client = self.client.clone();
742            let endpoint = self.endpoint.clone();
743            let anthropic_version = self.anthropic_version.clone();
744            let timeout = self.timeout.clone();
745
746            let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
747
748            tokio::spawn(async move {
749                let tools = build_anthropic_tools(&req.tools);
750                let messages = sanitize_messages_for_anthropic(&req.messages);
751                let messages = if messages.is_empty() {
752                    serde_json::json!([{"role": "user", "content": req.prompt}])
753                } else {
754                    serde_json::to_value(messages).unwrap_or_else(
755                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
756                    )
757                };
758                let mut body = serde_json::json!({
759                    "model": req.model,
760                    "max_tokens": 4096,
761                    "system": req.system_prompt_body,
762                    "messages": messages,
763                    "stream": true,
764                });
765                if !tools.is_null() {
766                    body["tools"] = tools;
767                }
768
769                let send_fut = client
770                    .post(&endpoint)
771                    .header("x-api-key", &api_key)
772                    .header("anthropic-version", &anthropic_version)
773                    .header("content-type", "application/json")
774                    .json(&body)
775                    .send();
776
777                let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
778                    Ok(Ok(r)) => r,
779                    Ok(Err(e)) => {
780                        let _ = tx
781                            .send(Err(ProviderError::Network {
782                                message: e.to_string(),
783                            }))
784                            .await;
785                        return;
786                    }
787                    Err(_) => {
788                        let _ = tx
789                            .send(Err(ProviderError::Timeout {
790                                elapsed_ms: timeout.per_request.as_millis() as u64,
791                            }))
792                            .await;
793                        return;
794                    }
795                };
796
797                if !resp.status().is_success() {
798                    let status = resp.status().as_u16();
799                    let text = resp.text().await.unwrap_or_default();
800                    let _ = tx
801                        .send(Err(ProviderError::Api {
802                            status,
803                            message: text,
804                        }))
805                        .await;
806                    return;
807                }
808
809                let mut buf = String::new();
810                let mut bytes = resp.bytes_stream();
811                use futures_util::StreamExt;
812
813                let mut current_tool_use: Option<(String, String, String)> = None;
814
815                while let Some(chunk) = bytes.next().await {
816                    let chunk = match chunk {
817                        Ok(c) => c,
818                        Err(e) => {
819                            let _ = tx
820                                .send(Err(ProviderError::Network {
821                                    message: e.to_string(),
822                                }))
823                                .await;
824                            return;
825                        }
826                    };
827
828                    buf.push_str(&String::from_utf8_lossy(&chunk));
829
830                    // SSE frames are separated by a blank line.
831                    while let Some(idx) = buf.find("\n\n") {
832                        let frame: String = buf.drain(..(idx + 2)).collect();
833                        let mut data_lines = Vec::new();
834                        for line in frame.lines() {
835                            let line = line.trim();
836                            if let Some(rest) = line.strip_prefix("data:") {
837                                let payload = rest.trim();
838                                if !payload.is_empty() {
839                                    data_lines.push(payload.to_string());
840                                }
841                            }
842                        }
843
844                        if data_lines.is_empty() {
845                            continue;
846                        }
847
848                        let data = data_lines.join("\n");
849                        if data == "[DONE]" {
850                            continue;
851                        }
852
853                        let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
854                            continue;
855                        };
856                        let typ = event["type"].as_str().unwrap_or("");
857                        match typ {
858                            "content_block_start" => {
859                                let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
860                                if cb_type == "text" {
861                                    if let Some(t) = event["content_block"]["text"].as_str() {
862                                        if !t.is_empty() {
863                                            let _ = tx
864                                                .send(Ok(ApiEvent::MessageDelta {
865                                                    text: t.to_string(),
866                                                }))
867                                                .await;
868                                        }
869                                    }
870                                } else if cb_type == "tool_use" {
871                                    let id = event["content_block"]["id"]
872                                        .as_str()
873                                        .unwrap_or("")
874                                        .to_string();
875                                    let name = event["content_block"]["name"]
876                                        .as_str()
877                                        .unwrap_or("")
878                                        .to_string();
879                                    let input = event["content_block"]["input"].clone();
880                                    let input_str = if input.is_null() {
881                                        String::new()
882                                    } else {
883                                        serde_json::to_string(&input).unwrap_or_default()
884                                    };
885                                    current_tool_use = Some((id, name, input_str));
886                                } else if cb_type == "thinking" {
887                                    if let Some(t) = event["content_block"]["thinking"].as_str() {
888                                        if !t.is_empty() {
889                                            let _ = tx
890                                                .send(Ok(ApiEvent::ThinkingDelta {
891                                                    text: t.to_string(),
892                                                }))
893                                                .await;
894                                        }
895                                    }
896                                }
897                            }
898                            "content_block_delta" => {
899                                let delta_type = event["delta"]["type"].as_str().unwrap_or("");
900                                match delta_type {
901                                    "text_delta" => {
902                                        if let Some(t) = event["delta"]["text"].as_str() {
903                                            let _ = tx
904                                                .send(Ok(ApiEvent::MessageDelta {
905                                                    text: t.to_string(),
906                                                }))
907                                                .await;
908                                        }
909                                    }
910                                    "thinking_delta" => {
911                                        if let Some(t) = event["delta"]["thinking"].as_str() {
912                                            let _ = tx
913                                                .send(Ok(ApiEvent::ThinkingDelta {
914                                                    text: t.to_string(),
915                                                }))
916                                                .await;
917                                        }
918                                    }
919                                    "input_json_delta" => {
920                                        if let Some(partial) =
921                                            event["delta"]["partial_json"].as_str()
922                                        {
923                                            if let Some((_id, _name, input_buf)) =
924                                                current_tool_use.as_mut()
925                                            {
926                                                input_buf.push_str(partial);
927                                            }
928                                        }
929                                    }
930                                    _ => {}
931                                }
932                            }
933                            "content_block_stop" => {
934                                if let Some((id, name, input)) = current_tool_use.take() {
935                                    let input = if input.trim().is_empty() {
936                                        "{}".to_string()
937                                    } else {
938                                        input
939                                    };
940                                    let _ = tx
941                                        .send(Ok(ApiEvent::ToolUse {
942                                            tool_use: ToolUseEvent { id, name, input },
943                                        }))
944                                        .await;
945                                }
946                            }
947                            "message_delta" => {
948                                if let Some(usage) = event.get("usage") {
949                                    let _ = tx
950                                        .send(Ok(ApiEvent::Usage {
951                                            usage: UsageEvent {
952                                                input_tokens: usage["input_tokens"]
953                                                    .as_u64()
954                                                    .unwrap_or(0),
955                                                output_tokens: usage["output_tokens"]
956                                                    .as_u64()
957                                                    .unwrap_or(0),
958                                                cache_read_tokens: usage["cache_read_input_tokens"]
959                                                    .as_u64()
960                                                    .unwrap_or(0),
961                                                cache_write_tokens:
962                                                    usage["cache_creation_input_tokens"]
963                                                        .as_u64()
964                                                        .unwrap_or(0),
965                                            },
966                                        }))
967                                        .await;
968                                }
969                            }
970                            "message_stop" => {
971                                let _ = tx.send(Ok(ApiEvent::Completed)).await;
972                                return;
973                            }
974                            _ => {}
975                        }
976                    }
977                }
978
979                let _ = tx.send(Ok(ApiEvent::Completed)).await;
980            });
981
982            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
983                rx.recv().await.map(|item| (item, rx))
984            }))
985        }
986    }
987}
988
989// --- Provider factory (picks provider from env) ---
990
991pub fn create_provider() -> BoxedProvider {
992    let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
993
994    match provider_name.as_str() {
995        #[cfg(feature = "anthropic")]
996        "anthropic" => {
997            if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
998                tracing::info!("Using Anthropic provider");
999                return Box::new(p);
1000            }
1001            tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
1002        }
1003        #[cfg(not(feature = "anthropic"))]
1004        "anthropic" => {
1005            tracing::warn!(
1006                "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
1007            );
1008        }
1009        _ => {}
1010    }
1011
1012    tracing::info!("Using mock provider");
1013    Box::new(MockProvider)
1014}
1015
1016// --- Helper: collect stream into response ---
1017
1018pub async fn collect_stream_to_response(
1019    stream: EventStream,
1020    request: &CompletionRequest,
1021) -> Result<CompletionResponse, ProviderError> {
1022    use futures_util::StreamExt;
1023    let mut text = String::new();
1024    let mut thinking = String::new();
1025
1026    let mut s = stream;
1027    while let Some(event) = s.next().await {
1028        match event? {
1029            ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1030            ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1031            ApiEvent::Usage { usage: _ } => {}
1032            ApiEvent::Completed => break,
1033            ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1034        }
1035    }
1036
1037    Ok(CompletionResponse {
1038        system_prompt: request.system_prompt_name.clone(),
1039        response: text,
1040        tool_count: request.tools.len(),
1041        skill_count: request.skill_count,
1042        mcp_server_count: request.mcp_servers.len(),
1043    })
1044}
1045
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request shared by the mock-provider tests: no history, tools,
    /// skills, or MCP servers.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Streams `request` through [`MockProvider`] and returns the successful events
    /// in arrival order; `Err` items are dropped.
    async fn mock_events(request: &CompletionRequest) -> Vec<ApiEvent> {
        let provider = MockProvider;
        provider
            .stream(request)
            .filter_map(|event| async move { event.ok() })
            .collect()
            .await
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = mock_events(&sample_request()).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = mock_events(&sample_request()).await;

        // `first()` / `get(1)` make a too-short stream fail the assertion instead of
        // panicking on an out-of-bounds index.
        assert!(matches!(
            events.first(),
            Some(ApiEvent::ThinkingDelta { .. })
        ));
        assert!(matches!(events.get(1), Some(ApiEvent::MessageDelta { .. })));

        let usage_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let request = sample_request();

        let direct = MockProvider.complete(&request).await.unwrap();
        let events = mock_events(&request).await;

        // Concatenating every message delta must reproduce the non-streaming response.
        let text: String = events
            .iter()
            .filter_map(|event| match event {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(text, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let mut account = UsageAccount::default();
        account.record(&UsageEvent {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 20,
        });
        account.record(&UsageEvent {
            input_tokens: 200,
            output_tokens: 75,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        });

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let err = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(err.to_string().contains("5000"));

        let err = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(err.to_string().contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::normalize_anthropic_endpoint;

        /// A bare base URL gets `/v1/messages` appended.
        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434"),
                "http://localhost:11434/v1/messages"
            );
        }

        /// A trailing slash on the base URL does not produce a double slash.
        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/"),
                "http://localhost:11434/v1/messages"
            );
        }

        /// A URL that already ends in `/v1/messages` is returned unchanged.
        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/v1/messages"),
                "http://localhost:11434/v1/messages"
            );
        }

        /// The public Anthropic host (HTTPS, no path) is normalized like any other
        /// base URL.
        #[test]
        fn test_normalize_anthropic_endpoint_custom_path() {
            assert_eq!(
                normalize_anthropic_endpoint("https://api.anthropic.com"),
                "https://api.anthropic.com/v1/messages"
            );
        }
    }
}