Skip to main content

clawedcode_api/
lib.rs

1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11// --- Request / Response types ---
12
/// Provider-agnostic description of one completion call.
///
/// The same value is handed to both `Provider::complete` and
/// `Provider::stream`.
#[derive(Debug, Clone)]
pub struct CompletionRequest {
    /// Model identifier; the Anthropic provider sends it as the `model` field
    /// of the request body.
    pub model: String,
    // NOTE(review): `prompt_pack` is not read anywhere in the visible part of
    // this file; presumably informational — confirm against callers.
    pub prompt_pack: String,
    /// Name of the system prompt; echoed back in
    /// `CompletionResponse::system_prompt`.
    pub system_prompt_name: String,
    /// Full system prompt text; the Anthropic provider sends it as `system`.
    pub system_prompt_body: String,
    /// Raw prompt text. The mock providers inspect it, and the Anthropic
    /// provider falls back to it as a single user message when `messages`
    /// is empty after sanitizing.
    pub prompt: String,
    /// Structured conversation history. Providers may prefer this over `prompt`.
    pub messages: Vec<ProviderMessage>,
    /// Tools offered to the model; their count is reported in the response.
    pub tools: Vec<ToolSpec>,
    /// Number of skills; copied verbatim into the response.
    pub skill_count: usize,
    /// MCP server configurations keyed by name; only the count is reported
    /// in the response by the providers visible here.
    pub mcp_servers: BTreeMap<String, McpServerConfig>,
}
26
/// Result of a non-streaming completion.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Copy of the request's `system_prompt_name` (the name, not the body).
    pub system_prompt: String,
    /// Response text produced by the provider.
    pub response: String,
    /// Number of tools that were present in the request.
    pub tool_count: usize,
    /// Copy of the request's `skill_count`.
    pub skill_count: usize,
    /// Number of MCP servers that were present in the request.
    pub mcp_server_count: usize,
}
35
36// --- Conversation message model (provider-facing) ---
37
/// Author of a conversation message.
///
/// Serialized in snake_case, i.e. as `"user"` / `"assistant"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderRole {
    User,
    Assistant,
}
44
/// One message of the conversation history: a role plus an ordered list of
/// content blocks.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProviderMessage {
    /// Who authored the message.
    pub role: ProviderRole,
    /// Content blocks in the order they appear in the message.
    pub content: Vec<ProviderContentBlock>,
}
50
/// A single block inside a [`ProviderMessage`].
///
/// Serialized with an internal `type` tag whose value is the snake_case
/// variant name (`text`, `tool_use`, `tool_result`, `thinking`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ProviderContentBlock {
    /// Plain text.
    Text {
        text: String,
    },
    /// A tool invocation, identified by `id`.
    ToolUse {
        id: String,
        name: String,
        /// Tool arguments as JSON; falls back to the `serde_json::Value`
        /// default when the field is missing on deserialization.
        #[serde(default)]
        input: serde_json::Value,
    },
    /// Output of a tool invocation; `tool_use_id` names the `ToolUse` it
    /// answers.
    ToolResult {
        tool_use_id: String,
        content: String,
        /// `false` when the field is missing on deserialization.
        #[serde(default)]
        is_error: bool,
    },
    /// Model reasoning text. The Anthropic provider strips these blocks
    /// before sending the history.
    Thinking {
        thinking: String,
    },
}
73
74// --- Streaming event model ---
75
/// Streaming payload describing a tool call.
#[derive(Debug, Clone, Serialize)]
pub struct ToolUseEvent {
    /// Identifier of this tool call.
    pub id: String,
    /// Name of the tool to run.
    pub name: String,
    /// Tool arguments as a string; the mock providers put serialized JSON here.
    pub input: String,
}
82
/// Streaming payload describing the outcome of a tool call.
#[derive(Debug, Clone, Serialize)]
pub struct ToolResultEvent {
    /// Identifier of the tool call this result belongs to.
    pub tool_use_id: String,
    /// Tool output text.
    pub content: String,
    /// Whether the tool reported a failure.
    pub is_error: bool,
}
89
/// Token counts reported for one request; accumulated by
/// `UsageAccount::record`.
#[derive(Debug, Clone, Serialize)]
pub struct UsageEvent {
    pub input_tokens: u64,
    pub output_tokens: u64,
    pub cache_read_tokens: u64,
    pub cache_write_tokens: u64,
}
97
/// Item type carried by an `EventStream`.
///
/// The mock providers always emit `Completed` as the last event.
#[derive(Debug, Clone, Serialize)]
pub enum ApiEvent {
    /// A piece of response text.
    MessageDelta { text: String },
    /// A piece of reasoning text.
    ThinkingDelta { text: String },
    /// The model requests a tool call.
    ToolUse { tool_use: ToolUseEvent },
    /// The result of a tool call.
    ToolResult { tool_result: ToolResultEvent },
    /// Token accounting for the request.
    Usage { usage: UsageEvent },
    /// End of the event sequence.
    Completed,
}
107
108// --- Error types ---
109
/// Failure reported by a provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProviderError {
    /// Transport-level failure (sending the request or reading the body).
    Network { message: String },
    /// The endpoint answered with a non-success HTTP status; `message` holds
    /// the response body text.
    Api { status: u16, message: String },
    /// The response body could not be decoded.
    Parse { message: String },
    /// The request did not finish in time; the Anthropic provider stores the
    /// configured per-request timeout (in milliseconds) here.
    Timeout { elapsed_ms: u64 },
    /// The non-streaming request loop gave up; `last_error` is the display
    /// text of the last failure.
    RetryExhausted { attempts: u32, last_error: String },
    /// Any other failure.
    Other { message: String },
}
119
120impl std::fmt::Display for ProviderError {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            ProviderError::Network { message } => write!(f, "Network error: {message}"),
124            ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125            ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126            ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127            ProviderError::RetryExhausted {
128                attempts,
129                last_error,
130            } => {
131                write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132            }
133            ProviderError::Other { message } => write!(f, "{message}"),
134        }
135    }
136}
137
// Marker impl: relies on the trait's default methods (no `source` override),
// which lets `ProviderError` be used wherever `std::error::Error` is expected.
impl std::error::Error for ProviderError {}
139
140// --- Usage accounting ---
141
/// Running totals of token usage across requests.
///
/// All counters start at zero via `Default` and are advanced by
/// `UsageAccount::record`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct UsageAccount {
    pub total_input_tokens: u64,
    pub total_output_tokens: u64,
    pub total_cache_read_tokens: u64,
    pub total_cache_write_tokens: u64,
    /// Number of usage events recorded so far.
    pub request_count: u64,
}
150
151impl UsageAccount {
152    pub fn record(&mut self, usage: &UsageEvent) {
153        self.total_input_tokens += usage.input_tokens;
154        self.total_output_tokens += usage.output_tokens;
155        self.total_cache_read_tokens += usage.cache_read_tokens;
156        self.total_cache_write_tokens += usage.cache_write_tokens;
157        self.request_count += 1;
158    }
159}
160
161// --- Retry / Timeout envelope ---
162
/// Retry policy for provider requests.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on the number of attempts made for one request.
    pub max_attempts: u32,
    /// Delay used as the starting point between attempts.
    pub base_delay: Duration,
    /// Ceiling applied to a computed backoff delay.
    pub max_delay: Duration,
}

/// Defaults: three attempts, 200 ms base delay, 5 s maximum delay.
impl Default for RetryConfig {
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        RetryConfig {
            max_attempts: 3,
            base_delay,
            max_delay,
        }
    }
}
179
/// Timeout policy for provider requests.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Maximum time allowed for a single request attempt.
    pub per_request: Duration,
}

/// Default: 60 seconds per request.
impl Default for TimeoutConfig {
    fn default() -> Self {
        let per_request = Duration::from_secs(60);
        TimeoutConfig { per_request }
    }
}
192
193// --- Provider trait (async, streaming) ---
194
/// Boxed, pinned, `Send` stream of provider events; each item is either an
/// [`ApiEvent`] or the [`ProviderError`] that interrupted the stream.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
/// Interface implemented by every completion backend in this file (the mock
/// providers and the optional Anthropic provider).
///
/// Methods return boxed futures/streams rather than using `async fn`, which
/// keeps the trait usable as `dyn Provider` (see `BoxedProvider`).
pub trait Provider: Send + Sync {
    /// Runs `request` to completion.
    ///
    /// The returned future resolves to the full [`CompletionResponse`] or to
    /// a [`ProviderError`]; it may borrow from `self`.
    fn complete(
        &self,
        request: &CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;

    /// Starts `request` and returns its events as an [`EventStream`].
    fn stream(&self, request: &CompletionRequest) -> EventStream;
}
205
206// --- Boxed provider handle ---
207
/// Owned, type-erased [`Provider`].
pub type BoxedProvider = Box<dyn Provider>;
209
210// --- Mock provider (streams with delays) ---
211
/// Offline provider that answers with canned text and never performs network
/// I/O. Its stream delivers the canned events with a short pause between
/// consecutive events.
#[derive(Debug, Default, Clone)]
pub struct MockProvider;
214
215impl Provider for MockProvider {
216    fn complete(
217        &self,
218        request: &CompletionRequest,
219    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220        let response = mock_complete_response(request);
221        Box::pin(async move { Ok(response) })
222    }
223
224    fn stream(&self, request: &CompletionRequest) -> EventStream {
225        let req = request.clone();
226        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228        tokio::spawn(async move {
229            let events = mock_stream_events(&req);
230            for (idx, event) in events.into_iter().enumerate() {
231                if idx > 0 {
232                    tokio::time::sleep(Duration::from_millis(18)).await;
233                }
234                if tx.send(Ok(event)).await.is_err() {
235                    return;
236                }
237            }
238        });
239
240        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241            rx.recv().await.map(|item| (item, rx))
242        }))
243    }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247    if wants_read_cargo_toml(&request.prompt) {
248        if let Some(tool_content) = first_tool_result_content(&request.messages) {
249            let response = mock_summarize_cargo_toml(&tool_content);
250            return CompletionResponse {
251                system_prompt: request.system_prompt_name.clone(),
252                response,
253                tool_count: request.tools.len(),
254                skill_count: request.skill_count,
255                mcp_server_count: request.mcp_servers.len(),
256            };
257        }
258
259        return CompletionResponse {
260            system_prompt: request.system_prompt_name.clone(),
261            response: "I'll read Cargo.toml first.".to_string(),
262            tool_count: request.tools.len(),
263            skill_count: request.skill_count,
264            mcp_server_count: request.mcp_servers.len(),
265        };
266    }
267
268    let response = mock_plain_reply(request);
269
270    CompletionResponse {
271        system_prompt: request.system_prompt_name.clone(),
272        response,
273        tool_count: request.tools.len(),
274        skill_count: request.skill_count,
275        mcp_server_count: request.mcp_servers.len(),
276    }
277}
278
279fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
280    if wants_read_cargo_toml(&request.prompt) {
281        if let Some(tool_content) = first_tool_result_content(&request.messages) {
282            return vec![
283                ApiEvent::ThinkingDelta {
284                    text: "I have the Cargo.toml contents; summarizing.".to_string(),
285                },
286                ApiEvent::MessageDelta {
287                    text: mock_summarize_cargo_toml(&tool_content),
288                },
289                ApiEvent::Usage {
290                    usage: UsageEvent {
291                        input_tokens: 220,
292                        output_tokens: 90,
293                        cache_read_tokens: 0,
294                        cache_write_tokens: 0,
295                    },
296                },
297                ApiEvent::Completed,
298            ];
299        }
300
301        return vec![
302            ApiEvent::ThinkingDelta {
303                text: "I should read Cargo.toml to answer this.".to_string(),
304            },
305            ApiEvent::ToolUse {
306                tool_use: ToolUseEvent {
307                    id: "tool_1".to_string(),
308                    name: "read_file".to_string(),
309                    input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
310                },
311            },
312            ApiEvent::Completed,
313        ];
314    }
315
316    vec![
317        ApiEvent::ThinkingDelta {
318            text: "Thinking about how to help...".to_string(),
319        },
320        ApiEvent::MessageDelta {
321            text: mock_plain_reply(request),
322        },
323        ApiEvent::Usage {
324            usage: UsageEvent {
325                input_tokens: 80,
326                output_tokens: 40,
327                cache_read_tokens: 0,
328                cache_write_tokens: 0,
329            },
330        },
331        ApiEvent::Completed,
332    ]
333}
334
335fn mock_plain_reply(request: &CompletionRequest) -> String {
336    let prompt = request.prompt.trim().to_ascii_lowercase();
337
338    if prompt.is_empty() {
339        return "Hello! How can I help you today?".to_string();
340    }
341
342    if ["hello", "hi", "hey"]
343        .iter()
344        .any(|greeting| prompt == *greeting)
345    {
346        return "Hello! How can I assist you today?".to_string();
347    }
348
349    if prompt.contains("how are you") {
350        return "I'm doing well, thank you! What can I help you with?".to_string();
351    }
352
353    "I received your message. To get started with real AI-powered assistance, configure a provider like Claude or Ollama.".to_string()
354}
355
/// Returns `true` when the prompt (compared ASCII-case-insensitively)
/// mentions `cargo.toml` together with `read` or `summarize`.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    if !lowered.contains("cargo.toml") {
        return false;
    }
    ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(*verb))
}
360
361fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
362    for m in messages {
363        for b in &m.content {
364            if let ProviderContentBlock::ToolResult { content, .. } = b {
365                return Some(content.clone());
366            }
367        }
368    }
369    None
370}
371
/// Builds a canned summary of a `Cargo.toml` based on two substring checks:
/// whether it has a `[workspace]` section and whether it mentions `members`.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let mut sentences = vec!["Cargo.toml defines a Rust workspace."];
    if contents.contains("members") {
        sentences.push("It declares workspace members; this repo is a multi-crate workspace.");
    }
    sentences.push("Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui.");

    // One sentence per line, no trailing newline.
    sentences.join("\n")
}
384
385// --- MockToolProvider: deterministic provider for tests that triggers a tool call ---
386
/// A mock provider that, on the first call, emits a `read_file` ToolUse for
/// `Cargo.toml`. On subsequent calls (i.e. after the runtime has appended a
/// tool_result message), it returns a fixed final text response about the
/// workspace; the text does not depend on the tool result's contents.
///
/// Detection of "subsequent call" is done by checking whether any message in
/// the request contains a `ToolResult` content block (see
/// `has_tool_result_message`). The prompt text is not inspected.
#[derive(Debug, Default, Clone)]
pub struct MockToolProvider;
397
398impl Provider for MockToolProvider {
399    fn complete(
400        &self,
401        request: &CompletionRequest,
402    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
403        let response = mock_tool_complete_response(request);
404        Box::pin(async move { Ok(response) })
405    }
406
407    fn stream(&self, request: &CompletionRequest) -> EventStream {
408        let req = request.clone();
409        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
410
411        tokio::spawn(async move {
412            let events = mock_tool_stream_events(&req);
413            for (idx, event) in events.into_iter().enumerate() {
414                if idx > 0 {
415                    tokio::time::sleep(Duration::from_millis(5)).await;
416                }
417                if tx.send(Ok(event)).await.is_err() {
418                    return;
419                }
420            }
421        });
422
423        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
424            rx.recv().await.map(|item| (item, rx))
425        }))
426    }
427}
428
429fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
430    if has_tool_result_message(request) {
431        let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
432        return CompletionResponse {
433            system_prompt: request.system_prompt_name.clone(),
434            response,
435            tool_count: request.tools.len(),
436            skill_count: request.skill_count,
437            mcp_server_count: request.mcp_servers.len(),
438        };
439    }
440
441    CompletionResponse {
442        system_prompt: request.system_prompt_name.clone(),
443        response: "I'll read the Cargo.toml file.".to_string(),
444        tool_count: request.tools.len(),
445        skill_count: request.skill_count,
446        mcp_server_count: request.mcp_servers.len(),
447    }
448}
449
450fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
451    if has_tool_result_message(request) {
452        return vec![
453            ApiEvent::ThinkingDelta {
454                text: "I have the file contents now.".to_string(),
455            },
456            ApiEvent::MessageDelta {
457                text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
458            },
459            ApiEvent::Usage {
460                usage: UsageEvent {
461                    input_tokens: 200,
462                    output_tokens: 60,
463                    cache_read_tokens: 0,
464                    cache_write_tokens: 0,
465                },
466            },
467            ApiEvent::Completed,
468        ];
469    }
470
471    vec![
472        ApiEvent::ThinkingDelta {
473            text: "I should read the Cargo.toml file.".to_string(),
474        },
475        ApiEvent::ToolUse {
476            tool_use: ToolUseEvent {
477                id: "tool_1".to_string(),
478                name: "read_file".to_string(),
479                input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
480            },
481        },
482        ApiEvent::Usage {
483            usage: UsageEvent {
484                input_tokens: 100,
485                output_tokens: 30,
486                cache_read_tokens: 0,
487                cache_write_tokens: 0,
488            },
489        },
490        ApiEvent::Completed,
491    ]
492}
493
494/// Returns true if the conversation already contains tool result blocks,
495/// indicating this is a re-query after tool execution.
496fn has_tool_result_message(request: &CompletionRequest) -> bool {
497    request.messages.iter().any(|m| {
498        m.content
499            .iter()
500            .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
501    })
502}
503
504// --- Optional Anthropic provider (behind feature flag + env var) ---
505
506#[cfg(feature = "anthropic")]
507pub mod anthropic_provider {
508    use super::*;
509    use reqwest::Client;
510
511    #[derive(Debug, Clone)]
512    pub struct AnthropicProvider {
513        client: Client,
514        api_key: String,
515        endpoint: String,
516        anthropic_version: String,
517        retry: RetryConfig,
518        timeout: TimeoutConfig,
519    }
520
521    pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
522        if endpoint.contains("/v1/messages") {
523            endpoint.to_string()
524        } else {
525            let endpoint = endpoint.trim_end_matches('/');
526            format!("{}/v1/messages", endpoint)
527        }
528    }
529
530    impl AnthropicProvider {
531        pub fn from_env() -> Option<Self> {
532            let api_key = std::env::var("ANTHROPIC_API_KEY")
533                .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
534                .ok()?;
535            let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
536                .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
537                .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
538            let endpoint = normalize_anthropic_endpoint(&endpoint);
539            let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
540                .unwrap_or_else(|_| "2023-06-01".to_string());
541            Some(Self {
542                client: Client::new(),
543                api_key,
544                endpoint,
545                anthropic_version,
546                retry: RetryConfig::default(),
547                timeout: TimeoutConfig::default(),
548            })
549        }
550
551        pub fn new(api_key: String, endpoint: String) -> Self {
552            Self {
553                client: Client::new(),
554                api_key,
555                endpoint,
556                anthropic_version: "2023-06-01".to_string(),
557                retry: RetryConfig::default(),
558                timeout: TimeoutConfig::default(),
559            }
560        }
561    }
562
563    fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
564        if tools.is_empty() {
565            return serde_json::Value::Null;
566        }
567        serde_json::Value::Array(
568            tools
569                .iter()
570                .map(|t| {
571                    serde_json::json!({
572                        "name": t.name,
573                        "description": t.description,
574                        "input_schema": t.input_schema,
575                    })
576                })
577                .collect(),
578        )
579    }
580
581    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
582        messages
583            .iter()
584            .map(|m| ProviderMessage {
585                role: m.role.clone(),
586                content: m
587                    .content
588                    .iter()
589                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
590                    .cloned()
591                    .collect(),
592            })
593            .filter(|m| !m.content.is_empty())
594            .collect()
595    }
596
597    pub(crate) fn initial_tool_input_buffer(input: &serde_json::Value) -> String {
598        match input {
599            serde_json::Value::Null => String::new(),
600            serde_json::Value::Object(map) if map.is_empty() => String::new(),
601            serde_json::Value::String(text) if text.trim().is_empty() => String::new(),
602            other => serde_json::to_string(other).unwrap_or_default(),
603        }
604    }
605
606    pub(crate) fn finalize_tool_input_buffer(input: String) -> String {
607        if input.trim().is_empty() {
608            "{}".to_string()
609        } else {
610            input
611        }
612    }
613
614    impl Provider for AnthropicProvider {
615        fn complete(
616            &self,
617            request: &CompletionRequest,
618        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
619        {
620            let req = request.clone();
621            let api_key = self.api_key.clone();
622            let client = self.client.clone();
623            let endpoint = self.endpoint.clone();
624            let anthropic_version = self.anthropic_version.clone();
625            let retry = self.retry.clone();
626            let timeout = self.timeout.clone();
627            Box::pin(async move {
628                let tools = build_anthropic_tools(&req.tools);
629                let messages = sanitize_messages_for_anthropic(&req.messages);
630                let messages = if messages.is_empty() {
631                    serde_json::json!([{"role": "user", "content": req.prompt}])
632                } else {
633                    serde_json::to_value(messages).unwrap_or_else(
634                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
635                    )
636                };
637                let mut body = serde_json::json!({
638                    "model": req.model,
639                    "max_tokens": 4096,
640                    "system": req.system_prompt_body,
641                    "messages": messages,
642                });
643                if !tools.is_null() {
644                    body["tools"] = tools;
645                }
646
647                let mut last_err: Option<ProviderError> = None;
648
649                for attempt in 1..=retry.max_attempts {
650                    let send_fut = client
651                        .post(&endpoint)
652                        .header("x-api-key", &api_key)
653                        .header("anthropic-version", &anthropic_version)
654                        .header("content-type", "application/json")
655                        .json(&body)
656                        .send();
657
658                    let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
659                        Ok(Ok(r)) => r,
660                        Ok(Err(e)) => {
661                            last_err = Some(ProviderError::Network {
662                                message: e.to_string(),
663                            });
664                            if attempt < retry.max_attempts {
665                                let backoff_ms = (retry.base_delay.as_millis() as u64)
666                                    .saturating_mul(1u64 << (attempt - 1));
667                                tokio::time::sleep(Duration::from_millis(
668                                    backoff_ms.min(retry.max_delay.as_millis() as u64),
669                                ))
670                                .await;
671                                continue;
672                            }
673                            break;
674                        }
675                        Err(_) => {
676                            last_err = Some(ProviderError::Timeout {
677                                elapsed_ms: timeout.per_request.as_millis() as u64,
678                            });
679                            if attempt < retry.max_attempts {
680                                tokio::time::sleep(retry.base_delay).await;
681                                continue;
682                            }
683                            break;
684                        }
685                    };
686
687                    let status = resp.status().as_u16();
688                    if !resp.status().is_success() {
689                        let text = resp.text().await.unwrap_or_default();
690                        let err = ProviderError::Api {
691                            status,
692                            message: text,
693                        };
694                        last_err = Some(err);
695
696                        // Retry 5xx; fail fast otherwise.
697                        let retryable = (500..=599).contains(&status);
698                        if retryable && attempt < retry.max_attempts {
699                            tokio::time::sleep(retry.base_delay).await;
700                            continue;
701                        }
702                        break;
703                    }
704
705                    let json: serde_json::Value =
706                        resp.json().await.map_err(|e| ProviderError::Parse {
707                            message: e.to_string(),
708                        })?;
709
710                    let response = json["content"]
711                        .as_array()
712                        .map(|arr| {
713                            arr.iter()
714                                .filter_map(|block| block["text"].as_str())
715                                .collect::<Vec<_>>()
716                                .join("")
717                        })
718                        .unwrap_or_default();
719
720                    return Ok(CompletionResponse {
721                        system_prompt: req.system_prompt_name.clone(),
722                        response,
723                        tool_count: req.tools.len(),
724                        skill_count: req.skill_count,
725                        mcp_server_count: req.mcp_servers.len(),
726                    });
727                }
728
729                Err(match last_err {
730                    Some(e) => ProviderError::RetryExhausted {
731                        attempts: retry.max_attempts,
732                        last_error: e.to_string(),
733                    },
734                    None => ProviderError::Other {
735                        message: "request failed".to_string(),
736                    },
737                })
738            })
739        }
740
741        fn stream(&self, request: &CompletionRequest) -> EventStream {
742            let req = request.clone();
743            let api_key = self.api_key.clone();
744            let client = self.client.clone();
745            let endpoint = self.endpoint.clone();
746            let anthropic_version = self.anthropic_version.clone();
747            let timeout = self.timeout.clone();
748
749            let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
750
751            tokio::spawn(async move {
752                let tools = build_anthropic_tools(&req.tools);
753                let messages = sanitize_messages_for_anthropic(&req.messages);
754                let messages = if messages.is_empty() {
755                    serde_json::json!([{"role": "user", "content": req.prompt}])
756                } else {
757                    serde_json::to_value(messages).unwrap_or_else(
758                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
759                    )
760                };
761                let mut body = serde_json::json!({
762                    "model": req.model,
763                    "max_tokens": 4096,
764                    "system": req.system_prompt_body,
765                    "messages": messages,
766                    "stream": true,
767                });
768                if !tools.is_null() {
769                    body["tools"] = tools;
770                }
771
772                let send_fut = client
773                    .post(&endpoint)
774                    .header("x-api-key", &api_key)
775                    .header("anthropic-version", &anthropic_version)
776                    .header("content-type", "application/json")
777                    .json(&body)
778                    .send();
779
780                let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
781                    Ok(Ok(r)) => r,
782                    Ok(Err(e)) => {
783                        let _ = tx
784                            .send(Err(ProviderError::Network {
785                                message: e.to_string(),
786                            }))
787                            .await;
788                        return;
789                    }
790                    Err(_) => {
791                        let _ = tx
792                            .send(Err(ProviderError::Timeout {
793                                elapsed_ms: timeout.per_request.as_millis() as u64,
794                            }))
795                            .await;
796                        return;
797                    }
798                };
799
800                if !resp.status().is_success() {
801                    let status = resp.status().as_u16();
802                    let text = resp.text().await.unwrap_or_default();
803                    let _ = tx
804                        .send(Err(ProviderError::Api {
805                            status,
806                            message: text,
807                        }))
808                        .await;
809                    return;
810                }
811
812                let mut buf = String::new();
813                let mut bytes = resp.bytes_stream();
814                use futures_util::StreamExt;
815
816                let mut current_tool_use: Option<(String, String, String)> = None;
817
818                while let Some(chunk) = bytes.next().await {
819                    let chunk = match chunk {
820                        Ok(c) => c,
821                        Err(e) => {
822                            let _ = tx
823                                .send(Err(ProviderError::Network {
824                                    message: e.to_string(),
825                                }))
826                                .await;
827                            return;
828                        }
829                    };
830
831                    buf.push_str(&String::from_utf8_lossy(&chunk));
832
833                    // SSE frames are separated by a blank line.
834                    while let Some(idx) = buf.find("\n\n") {
835                        let frame: String = buf.drain(..(idx + 2)).collect();
836                        let mut data_lines = Vec::new();
837                        for line in frame.lines() {
838                            let line = line.trim();
839                            if let Some(rest) = line.strip_prefix("data:") {
840                                let payload = rest.trim();
841                                if !payload.is_empty() {
842                                    data_lines.push(payload.to_string());
843                                }
844                            }
845                        }
846
847                        if data_lines.is_empty() {
848                            continue;
849                        }
850
851                        let data = data_lines.join("\n");
852                        if data == "[DONE]" {
853                            continue;
854                        }
855
856                        let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
857                            continue;
858                        };
859                        let typ = event["type"].as_str().unwrap_or("");
860                        match typ {
861                            "content_block_start" => {
862                                let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
863                                if cb_type == "text" {
864                                    if let Some(t) = event["content_block"]["text"].as_str() {
865                                        if !t.is_empty() {
866                                            let _ = tx
867                                                .send(Ok(ApiEvent::MessageDelta {
868                                                    text: t.to_string(),
869                                                }))
870                                                .await;
871                                        }
872                                    }
873                                } else if cb_type == "tool_use" {
874                                    let id = event["content_block"]["id"]
875                                        .as_str()
876                                        .unwrap_or("")
877                                        .to_string();
878                                    let name = event["content_block"]["name"]
879                                        .as_str()
880                                        .unwrap_or("")
881                                        .to_string();
882                                    let input =
883                                        initial_tool_input_buffer(&event["content_block"]["input"]);
884                                    let input_str = input;
885                                    current_tool_use = Some((id, name, input_str));
886                                } else if cb_type == "thinking" {
887                                    if let Some(t) = event["content_block"]["thinking"].as_str() {
888                                        if !t.is_empty() {
889                                            let _ = tx
890                                                .send(Ok(ApiEvent::ThinkingDelta {
891                                                    text: t.to_string(),
892                                                }))
893                                                .await;
894                                        }
895                                    }
896                                }
897                            }
898                            "content_block_delta" => {
899                                let delta_type = event["delta"]["type"].as_str().unwrap_or("");
900                                match delta_type {
901                                    "text_delta" => {
902                                        if let Some(t) = event["delta"]["text"].as_str() {
903                                            let _ = tx
904                                                .send(Ok(ApiEvent::MessageDelta {
905                                                    text: t.to_string(),
906                                                }))
907                                                .await;
908                                        }
909                                    }
910                                    "thinking_delta" => {
911                                        if let Some(t) = event["delta"]["thinking"].as_str() {
912                                            let _ = tx
913                                                .send(Ok(ApiEvent::ThinkingDelta {
914                                                    text: t.to_string(),
915                                                }))
916                                                .await;
917                                        }
918                                    }
919                                    "input_json_delta" => {
920                                        if let Some(partial) =
921                                            event["delta"]["partial_json"].as_str()
922                                        {
923                                            if let Some((_id, _name, input_buf)) =
924                                                current_tool_use.as_mut()
925                                            {
926                                                input_buf.push_str(partial);
927                                            }
928                                        }
929                                    }
930                                    _ => {}
931                                }
932                            }
933                            "content_block_stop" => {
934                                if let Some((id, name, input)) = current_tool_use.take() {
935                                    let input = finalize_tool_input_buffer(input);
936                                    let _ = tx
937                                        .send(Ok(ApiEvent::ToolUse {
938                                            tool_use: ToolUseEvent { id, name, input },
939                                        }))
940                                        .await;
941                                }
942                            }
943                            "message_delta" => {
944                                if let Some(usage) = event.get("usage") {
945                                    let _ = tx
946                                        .send(Ok(ApiEvent::Usage {
947                                            usage: UsageEvent {
948                                                input_tokens: usage["input_tokens"]
949                                                    .as_u64()
950                                                    .unwrap_or(0),
951                                                output_tokens: usage["output_tokens"]
952                                                    .as_u64()
953                                                    .unwrap_or(0),
954                                                cache_read_tokens: usage["cache_read_input_tokens"]
955                                                    .as_u64()
956                                                    .unwrap_or(0),
957                                                cache_write_tokens:
958                                                    usage["cache_creation_input_tokens"]
959                                                        .as_u64()
960                                                        .unwrap_or(0),
961                                            },
962                                        }))
963                                        .await;
964                                }
965                            }
966                            "message_stop" => {
967                                let _ = tx.send(Ok(ApiEvent::Completed)).await;
968                                return;
969                            }
970                            _ => {}
971                        }
972                    }
973                }
974
975                let _ = tx.send(Ok(ApiEvent::Completed)).await;
976            });
977
978            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
979                rx.recv().await.map(|item| (item, rx))
980            }))
981        }
982    }
983}
984
// --- Provider factory (picks provider from env) ---
986
987pub fn create_provider() -> BoxedProvider {
988    let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
989
990    match provider_name.as_str() {
991        #[cfg(feature = "anthropic")]
992        "anthropic" => {
993            if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
994                tracing::info!("Using Anthropic provider");
995                return Box::new(p);
996            }
997            tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
998        }
999        #[cfg(not(feature = "anthropic"))]
1000        "anthropic" => {
1001            tracing::warn!(
1002                "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
1003            );
1004        }
1005        _ => {}
1006    }
1007
1008    tracing::info!("Using mock provider");
1009    Box::new(MockProvider)
1010}
1011
// --- Helper: collect stream into response ---
1013
1014pub async fn collect_stream_to_response(
1015    stream: EventStream,
1016    request: &CompletionRequest,
1017) -> Result<CompletionResponse, ProviderError> {
1018    use futures_util::StreamExt;
1019    let mut text = String::new();
1020    let mut thinking = String::new();
1021
1022    let mut s = stream;
1023    while let Some(event) = s.next().await {
1024        match event? {
1025            ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1026            ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1027            ApiEvent::Usage { usage: _ } => {}
1028            ApiEvent::Completed => break,
1029            ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1030        }
1031    }
1032
1033    Ok(CompletionResponse {
1034        system_prompt: request.system_prompt_name.clone(),
1035        response: text,
1036        tool_count: request.tools.len(),
1037        skill_count: request.skill_count,
1038        mcp_server_count: request.mcp_servers.len(),
1039    })
1040}
1041
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request shared by the mock-provider tests: no history, tools,
    /// skills or MCP servers.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Streams `request` through `provider` and returns, in order, the events
    /// that arrived as `Ok`. `Err` items are dropped.
    async fn streamed_events(
        provider: &MockProvider,
        request: &CompletionRequest,
    ) -> Vec<ApiEvent> {
        provider
            .stream(request)
            .filter_map(|item| async move { item.ok() })
            .collect::<Vec<_>>()
            .await
    }

    /// The mock stream yields at least one event and ends with `Completed`.
    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let request = sample_request();
        let events = streamed_events(&MockProvider, &request).await;

        assert!(!events.is_empty());
        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    /// The mock stream opens with a thinking delta followed by a message
    /// delta, and reports usage before signalling completion.
    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let request = sample_request();
        let events = streamed_events(&MockProvider, &request).await;

        assert!(matches!(
            events.as_slice(),
            [
                ApiEvent::ThinkingDelta { .. },
                ApiEvent::MessageDelta { .. },
                ..
            ]
        ));

        let usage_at = events
            .iter()
            .position(|ev| matches!(ev, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_at = events
            .iter()
            .position(|ev| matches!(ev, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_at < completed_at);
    }

    /// Joining every streamed message delta reproduces the text returned by
    /// the non-streaming `complete` call for the same request.
    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let provider = MockProvider;
        let request = sample_request();

        let direct = provider.complete(&request).await.unwrap();
        let events = streamed_events(&provider, &request).await;

        let streamed: String = events
            .iter()
            .filter_map(|ev| match ev {
                ApiEvent::MessageDelta { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    /// Recording two usage events sums every token counter and counts one
    /// request per recorded event.
    #[test]
    fn usage_account_accumulates() {
        let samples = [
            UsageEvent {
                input_tokens: 100,
                output_tokens: 50,
                cache_read_tokens: 10,
                cache_write_tokens: 20,
            },
            UsageEvent {
                input_tokens: 200,
                output_tokens: 75,
                cache_read_tokens: 0,
                cache_write_tokens: 0,
            },
        ];

        let mut account = UsageAccount::default();
        for sample in &samples {
            account.record(sample);
        }

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    /// The rendered error text includes the numeric detail carried by the
    /// variant (elapsed milliseconds, attempt count).
    #[test]
    fn provider_error_display() {
        let timeout = ProviderError::Timeout { elapsed_ms: 5000 };
        let rendered = timeout.to_string();
        assert!(rendered.contains("5000"));

        let exhausted = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        let rendered = exhausted.to_string();
        assert!(rendered.contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::{
            finalize_tool_input_buffer, initial_tool_input_buffer, normalize_anthropic_endpoint,
        };

        /// Full messages URL expected for the local test host.
        const LOCAL_MESSAGES_URL: &str = "http://localhost:11434/v1/messages";

        /// A bare base URL gets `/v1/messages` appended.
        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            let normalized = normalize_anthropic_endpoint("http://localhost:11434");
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        /// A trailing slash on the base URL does not produce a doubled slash.
        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            let normalized = normalize_anthropic_endpoint("http://localhost:11434/");
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        /// A URL that already ends in `/v1/messages` is returned unchanged.
        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            let normalized = normalize_anthropic_endpoint(LOCAL_MESSAGES_URL);
            assert_eq!(normalized, LOCAL_MESSAGES_URL);
        }

        /// Despite the name, the input here is a bare host with no path; it
        /// receives the same `/v1/messages` suffix.
        #[test]
        fn test_normalize_anthropic_endpoint_custom_path() {
            let normalized = normalize_anthropic_endpoint("https://api.anthropic.com");
            assert_eq!(normalized, "https://api.anthropic.com/v1/messages");
        }

        /// An empty JSON object seeds the tool-input buffer with nothing.
        #[test]
        fn initial_tool_input_buffer_drops_empty_object() {
            let empty_object = serde_json::json!({});
            assert_eq!(initial_tool_input_buffer(&empty_object), String::new());
        }

        /// A populated JSON object seeds the buffer with its compact
        /// serialisation.
        #[test]
        fn initial_tool_input_buffer_keeps_non_empty_object() {
            let populated = serde_json::json!({"command": "ls -la"});
            assert_eq!(
                initial_tool_input_buffer(&populated),
                r#"{"command":"ls -la"}"#
            );
        }

        /// An empty buffer is finalised as the empty JSON object literal.
        #[test]
        fn finalize_tool_input_buffer_defaults_empty_to_object() {
            let finalized = finalize_tool_input_buffer(String::new());
            assert_eq!(finalized, "{}");
        }
    }
}