Skip to main content

clawedcode_api/
lib.rs

1use clawedcode_mcp::McpServerConfig;
2use clawedcode_tools::ToolSpec;
3use futures_core::Stream;
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6use std::future::Future;
7use std::pin::Pin;
8use std::time::Duration;
9use tokio::sync::mpsc;
10
11// --- Request / Response types ---
12
13#[derive(Debug, Clone)]
14pub struct CompletionRequest {
15    pub model: String,
16    pub prompt_pack: String,
17    pub system_prompt_name: String,
18    pub system_prompt_body: String,
19    pub prompt: String,
20    /// Structured conversation history. Providers may prefer this over `prompt`.
21    pub messages: Vec<ProviderMessage>,
22    pub tools: Vec<ToolSpec>,
23    pub skill_count: usize,
24    pub mcp_servers: BTreeMap<String, McpServerConfig>,
25}
26
27#[derive(Debug, Clone, Serialize)]
28pub struct CompletionResponse {
29    pub system_prompt: String,
30    pub response: String,
31    pub tool_count: usize,
32    pub skill_count: usize,
33    pub mcp_server_count: usize,
34}
35
36// --- Conversation message model (provider-facing) ---
37
38#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
39#[serde(rename_all = "snake_case")]
40pub enum ProviderRole {
41    User,
42    Assistant,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
46pub struct ProviderMessage {
47    pub role: ProviderRole,
48    pub content: Vec<ProviderContentBlock>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ProviderContentBlock {
54    Text {
55        text: String,
56    },
57    ToolUse {
58        id: String,
59        name: String,
60        #[serde(default)]
61        input: serde_json::Value,
62    },
63    ToolResult {
64        tool_use_id: String,
65        content: String,
66        #[serde(default)]
67        is_error: bool,
68    },
69    Thinking {
70        thinking: String,
71    },
72}
73
74// --- Streaming event model ---
75
76#[derive(Debug, Clone, Serialize)]
77pub struct ToolUseEvent {
78    pub id: String,
79    pub name: String,
80    pub input: String,
81}
82
83#[derive(Debug, Clone, Serialize)]
84pub struct ToolResultEvent {
85    pub tool_use_id: String,
86    pub content: String,
87    pub is_error: bool,
88}
89
90#[derive(Debug, Clone, Serialize)]
91pub struct UsageEvent {
92    pub input_tokens: u64,
93    pub output_tokens: u64,
94    pub cache_read_tokens: u64,
95    pub cache_write_tokens: u64,
96}
97
98#[derive(Debug, Clone, Serialize)]
99pub enum ApiEvent {
100    MessageDelta { text: String },
101    ThinkingDelta { text: String },
102    ToolUse { tool_use: ToolUseEvent },
103    ToolResult { tool_result: ToolResultEvent },
104    Usage { usage: UsageEvent },
105    Completed,
106}
107
108// --- Error types ---
109
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub enum ProviderError {
112    Network { message: String },
113    Api { status: u16, message: String },
114    Parse { message: String },
115    Timeout { elapsed_ms: u64 },
116    RetryExhausted { attempts: u32, last_error: String },
117    Other { message: String },
118}
119
120impl std::fmt::Display for ProviderError {
121    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122        match self {
123            ProviderError::Network { message } => write!(f, "Network error: {message}"),
124            ProviderError::Api { status, message } => write!(f, "API error ({status}): {message}"),
125            ProviderError::Parse { message } => write!(f, "Parse error: {message}"),
126            ProviderError::Timeout { elapsed_ms } => write!(f, "Timeout after {elapsed_ms}ms"),
127            ProviderError::RetryExhausted {
128                attempts,
129                last_error,
130            } => {
131                write!(f, "Retry exhausted after {attempts} attempts: {last_error}")
132            }
133            ProviderError::Other { message } => write!(f, "{message}"),
134        }
135    }
136}
137
138impl std::error::Error for ProviderError {}
139
140// --- Usage accounting ---
141
142#[derive(Debug, Clone, Default, Serialize)]
143pub struct UsageAccount {
144    pub total_input_tokens: u64,
145    pub total_output_tokens: u64,
146    pub total_cache_read_tokens: u64,
147    pub total_cache_write_tokens: u64,
148    pub request_count: u64,
149}
150
151impl UsageAccount {
152    pub fn record(&mut self, usage: &UsageEvent) {
153        self.total_input_tokens += usage.input_tokens;
154        self.total_output_tokens += usage.output_tokens;
155        self.total_cache_read_tokens += usage.cache_read_tokens;
156        self.total_cache_write_tokens += usage.cache_write_tokens;
157        self.request_count += 1;
158    }
159}
160
161// --- Retry / Timeout envelope ---
162
/// Retry policy applied to provider requests.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,
    /// Base wait before retrying.
    pub base_delay: Duration,
    /// Upper bound on any backoff wait.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three attempts, starting from a 200 ms delay and never waiting longer than 5 s.
    fn default() -> Self {
        let base_delay = Duration::from_millis(200);
        let max_delay = Duration::from_secs(5);
        RetryConfig {
            max_attempts: 3,
            base_delay,
            max_delay,
        }
    }
}
179
/// Timeout policy applied to provider requests.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Limit applied to each individual HTTP request.
    pub per_request: Duration,
}

impl Default for TimeoutConfig {
    /// One minute per request.
    fn default() -> Self {
        const ONE_MINUTE: Duration = Duration::from_secs(60);
        TimeoutConfig {
            per_request: ONE_MINUTE,
        }
    }
}
192
193// --- Provider trait (async, streaming) ---
194
195pub type EventStream = Pin<Box<dyn Stream<Item = Result<ApiEvent, ProviderError>> + Send>>;
196
197pub trait Provider: Send + Sync {
198    fn complete(
199        &self,
200        request: &CompletionRequest,
201    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>;
202
203    fn stream(&self, request: &CompletionRequest) -> EventStream;
204}
205
206// --- Boxed provider handle ---
207
208pub type BoxedProvider = Box<dyn Provider>;
209
210// --- Mock provider (streams with delays) ---
211
/// Offline provider that answers from canned replies and streams them with small delays.
#[derive(Clone, Debug, Default)]
pub struct MockProvider;
214
215impl Provider for MockProvider {
216    fn complete(
217        &self,
218        request: &CompletionRequest,
219    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
220        let response = mock_complete_response(request);
221        Box::pin(async move { Ok(response) })
222    }
223
224    fn stream(&self, request: &CompletionRequest) -> EventStream {
225        let req = request.clone();
226        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
227
228        tokio::spawn(async move {
229            let events = mock_stream_events(&req);
230            for (idx, event) in events.into_iter().enumerate() {
231                if idx > 0 {
232                    tokio::time::sleep(Duration::from_millis(18)).await;
233                }
234                if tx.send(Ok(event)).await.is_err() {
235                    return;
236                }
237            }
238        });
239
240        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
241            rx.recv().await.map(|item| (item, rx))
242        }))
243    }
244}
245
246fn mock_complete_response(request: &CompletionRequest) -> CompletionResponse {
247    if wants_read_cargo_toml(&request.prompt) {
248        if let Some(tool_content) = first_tool_result_content(&request.messages) {
249            let response = mock_summarize_cargo_toml(&tool_content);
250            return CompletionResponse {
251                system_prompt: request.system_prompt_name.clone(),
252                response,
253                tool_count: request.tools.len(),
254                skill_count: request.skill_count,
255                mcp_server_count: request.mcp_servers.len(),
256            };
257        }
258
259        return CompletionResponse {
260            system_prompt: request.system_prompt_name.clone(),
261            response: "I'll read Cargo.toml first.".to_string(),
262            tool_count: request.tools.len(),
263            skill_count: request.skill_count,
264            mcp_server_count: request.mcp_servers.len(),
265        };
266    }
267
268    let response = mock_plain_reply(request);
269
270    CompletionResponse {
271        system_prompt: request.system_prompt_name.clone(),
272        response,
273        tool_count: request.tools.len(),
274        skill_count: request.skill_count,
275        mcp_server_count: request.mcp_servers.len(),
276    }
277}
278
279fn mock_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
280    if wants_read_cargo_toml(&request.prompt) {
281        if let Some(tool_content) = first_tool_result_content(&request.messages) {
282            return vec![
283                ApiEvent::ThinkingDelta {
284                    text: "I have the Cargo.toml contents; summarizing.".to_string(),
285                },
286                ApiEvent::MessageDelta {
287                    text: mock_summarize_cargo_toml(&tool_content),
288                },
289                ApiEvent::Usage {
290                    usage: UsageEvent {
291                        input_tokens: 220,
292                        output_tokens: 90,
293                        cache_read_tokens: 0,
294                        cache_write_tokens: 0,
295                    },
296                },
297                ApiEvent::Completed,
298            ];
299        }
300
301        return vec![
302            ApiEvent::ThinkingDelta {
303                text: "I should read Cargo.toml to answer this.".to_string(),
304            },
305            ApiEvent::ToolUse {
306                tool_use: ToolUseEvent {
307                    id: "tool_1".to_string(),
308                    name: "read_file".to_string(),
309                    input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
310                },
311            },
312            ApiEvent::Completed,
313        ];
314    }
315
316    vec![
317        ApiEvent::ThinkingDelta {
318            text: "Thinking about how to help...".to_string(),
319        },
320        ApiEvent::MessageDelta {
321            text: mock_plain_reply(request),
322        },
323        ApiEvent::Usage {
324            usage: UsageEvent {
325                input_tokens: 80,
326                output_tokens: 40,
327                cache_read_tokens: 0,
328                cache_write_tokens: 0,
329            },
330        },
331        ApiEvent::Completed,
332    ]
333}
334
335fn mock_plain_reply(request: &CompletionRequest) -> String {
336    let prompt = request.prompt.trim().to_ascii_lowercase();
337
338    if prompt.is_empty() {
339        return "Hello! How can I help you today?".to_string();
340    }
341
342    if ["hello", "hi", "hey"]
343        .iter()
344        .any(|greeting| prompt == *greeting)
345    {
346        return "Hello! How can I assist you today?".to_string();
347    }
348
349    if prompt.contains("how are you") {
350        return "I'm doing well, thank you! What can I help you with?".to_string();
351    }
352
353    "I received your message. To get started with real AI-powered assistance, configure a provider like Claude or Ollama.".to_string()
354}
355
/// Mock heuristic: does the prompt mention `cargo.toml` together with "read" or "summarize"?
/// Matching is ASCII case-insensitive and substring-based.
fn wants_read_cargo_toml(prompt: &str) -> bool {
    let lowered = prompt.to_ascii_lowercase();
    let mentions_manifest = lowered.contains("cargo.toml");
    let asks_for_it = ["read", "summarize"]
        .iter()
        .any(|verb| lowered.contains(verb));
    mentions_manifest && asks_for_it
}
360
361fn first_tool_result_content(messages: &[ProviderMessage]) -> Option<String> {
362    for m in messages {
363        for b in &m.content {
364            if let ProviderContentBlock::ToolResult { content, .. } = b {
365                return Some(content.clone());
366            }
367        }
368    }
369    None
370}
371
/// Canned summary of a Cargo.toml, driven only by substring checks for `[workspace]` and
/// `members`. The listed crate names are fixed text, not parsed from `contents`.
fn mock_summarize_cargo_toml(contents: &str) -> String {
    if !contents.contains("[workspace]") {
        return "Cargo.toml does not look like a workspace manifest (no [workspace] section)."
            .to_string();
    }

    let members_line = if contents.contains("members") {
        "It declares workspace members; this repo is a multi-crate workspace.\n"
    } else {
        ""
    };
    format!(
        "Cargo.toml defines a Rust workspace.\n{members_line}Key crates include: clawedcode (cli), clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, clawedcode-tui."
    )
}
384
385// --- MockToolProvider: deterministic provider for tests that triggers a tool call ---
386
/// Deterministic mock provider for tests that drives exactly one tool round-trip.
///
/// While no message in `CompletionRequest::messages` contains a
/// `ProviderContentBlock::ToolResult` block, it asks for `Cargo.toml` to be read:
/// * `stream` emits a `read_file` tool use (id `tool_1`, input `{"path":"Cargo.toml"}`);
/// * `complete` replies "I'll read the Cargo.toml file.".
///
/// Once any message carries a tool-result block (i.e. after the runtime has executed the tool
/// and appended its output), both methods return a fixed final answer about the workspace.
/// Detection looks at content blocks, not at a message role: `ProviderRole` has no tool role.
/// The final answer is canned text and does not inspect the tool result's contents.
#[derive(Debug, Default, Clone)]
pub struct MockToolProvider;
397
398impl Provider for MockToolProvider {
399    fn complete(
400        &self,
401        request: &CompletionRequest,
402    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>> {
403        let response = mock_tool_complete_response(request);
404        Box::pin(async move { Ok(response) })
405    }
406
407    fn stream(&self, request: &CompletionRequest) -> EventStream {
408        let req = request.clone();
409        let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(32);
410
411        tokio::spawn(async move {
412            let events = mock_tool_stream_events(&req);
413            for (idx, event) in events.into_iter().enumerate() {
414                if idx > 0 {
415                    tokio::time::sleep(Duration::from_millis(5)).await;
416                }
417                if tx.send(Ok(event)).await.is_err() {
418                    return;
419                }
420            }
421        });
422
423        Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
424            rx.recv().await.map(|item| (item, rx))
425        }))
426    }
427}
428
429fn mock_tool_complete_response(request: &CompletionRequest) -> CompletionResponse {
430    if has_tool_result_message(request) {
431        let response = "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string();
432        return CompletionResponse {
433            system_prompt: request.system_prompt_name.clone(),
434            response,
435            tool_count: request.tools.len(),
436            skill_count: request.skill_count,
437            mcp_server_count: request.mcp_servers.len(),
438        };
439    }
440
441    CompletionResponse {
442        system_prompt: request.system_prompt_name.clone(),
443        response: "I'll read the Cargo.toml file.".to_string(),
444        tool_count: request.tools.len(),
445        skill_count: request.skill_count,
446        mcp_server_count: request.mcp_servers.len(),
447    }
448}
449
450fn mock_tool_stream_events(request: &CompletionRequest) -> Vec<ApiEvent> {
451    if has_tool_result_message(request) {
452        return vec![
453            ApiEvent::ThinkingDelta {
454                text: "I have the file contents now.".to_string(),
455            },
456            ApiEvent::MessageDelta {
457                text: "Based on the Cargo.toml file, this is a Rust workspace named 'clawedcode' with multiple crates including clawedcode-cli, clawedcode-core, clawedcode-api, clawedcode-tools, clawedcode-mcp, and clawedcode-tui.".to_string(),
458            },
459            ApiEvent::Usage {
460                usage: UsageEvent {
461                    input_tokens: 200,
462                    output_tokens: 60,
463                    cache_read_tokens: 0,
464                    cache_write_tokens: 0,
465                },
466            },
467            ApiEvent::Completed,
468        ];
469    }
470
471    vec![
472        ApiEvent::ThinkingDelta {
473            text: "I should read the Cargo.toml file.".to_string(),
474        },
475        ApiEvent::ToolUse {
476            tool_use: ToolUseEvent {
477                id: "tool_1".to_string(),
478                name: "read_file".to_string(),
479                input: serde_json::json!({"path": "Cargo.toml"}).to_string(),
480            },
481        },
482        ApiEvent::Usage {
483            usage: UsageEvent {
484                input_tokens: 100,
485                output_tokens: 30,
486                cache_read_tokens: 0,
487                cache_write_tokens: 0,
488            },
489        },
490        ApiEvent::Completed,
491    ]
492}
493
494/// Returns true if the conversation already contains tool result blocks,
495/// indicating this is a re-query after tool execution.
496fn has_tool_result_message(request: &CompletionRequest) -> bool {
497    request.messages.iter().any(|m| {
498        m.content
499            .iter()
500            .any(|b| matches!(b, ProviderContentBlock::ToolResult { .. }))
501    })
502}
503
504// --- Optional Anthropic provider (behind feature flag + env var) ---
505
506#[cfg(feature = "anthropic")]
507pub mod anthropic_provider {
508    use super::*;
509    use reqwest::Client;
510
511    #[derive(Debug, Clone)]
512    pub struct AnthropicProvider {
513        client: Client,
514        api_key: String,
515        endpoint: String,
516        anthropic_version: String,
517        retry: RetryConfig,
518        timeout: TimeoutConfig,
519    }
520
521    pub(crate) fn normalize_anthropic_endpoint(endpoint: &str) -> String {
522        if endpoint.contains("/v1/messages") {
523            endpoint.to_string()
524        } else {
525            let endpoint = endpoint.trim_end_matches('/');
526            format!("{}/v1/messages", endpoint)
527        }
528    }
529
530    impl AnthropicProvider {
531        pub fn from_env() -> Option<Self> {
532            let api_key = std::env::var("ANTHROPIC_API_KEY")
533                .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
534                .ok()?;
535            let endpoint = std::env::var("CLAWEDCODE_ANTHROPIC_ENDPOINT")
536                .or_else(|_| std::env::var("ANTHROPIC_BASE_URL"))
537                .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
538            let endpoint = normalize_anthropic_endpoint(&endpoint);
539            let anthropic_version = std::env::var("CLAWEDCODE_ANTHROPIC_VERSION")
540                .unwrap_or_else(|_| "2023-06-01".to_string());
541            Some(Self {
542                client: Client::new(),
543                api_key,
544                endpoint,
545                anthropic_version,
546                retry: RetryConfig::default(),
547                timeout: TimeoutConfig::default(),
548            })
549        }
550
551        pub fn new(api_key: String, endpoint: String) -> Self {
552            Self {
553                client: Client::new(),
554                api_key,
555                endpoint,
556                anthropic_version: "2023-06-01".to_string(),
557                retry: RetryConfig::default(),
558                timeout: TimeoutConfig::default(),
559            }
560        }
561    }
562
563    fn build_anthropic_tools(tools: &[ToolSpec]) -> serde_json::Value {
564        if tools.is_empty() {
565            return serde_json::Value::Null;
566        }
567        serde_json::Value::Array(
568            tools
569                .iter()
570                .map(|t| {
571                    serde_json::json!({
572                        "name": t.name,
573                        "description": t.description,
574                        "input_schema": t.input_schema,
575                    })
576                })
577                .collect(),
578        )
579    }
580
581    fn sanitize_messages_for_anthropic(messages: &[ProviderMessage]) -> Vec<ProviderMessage> {
582        messages
583            .iter()
584            .map(|m| ProviderMessage {
585                role: m.role.clone(),
586                content: m
587                    .content
588                    .iter()
589                    .filter(|b| !matches!(b, ProviderContentBlock::Thinking { .. }))
590                    .cloned()
591                    .collect(),
592            })
593            .filter(|m| !m.content.is_empty())
594            .collect()
595    }
596
597    impl Provider for AnthropicProvider {
598        fn complete(
599            &self,
600            request: &CompletionRequest,
601        ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, ProviderError>> + Send + '_>>
602        {
603            let req = request.clone();
604            let api_key = self.api_key.clone();
605            let client = self.client.clone();
606            let endpoint = self.endpoint.clone();
607            let anthropic_version = self.anthropic_version.clone();
608            let retry = self.retry.clone();
609            let timeout = self.timeout.clone();
610            Box::pin(async move {
611                let tools = build_anthropic_tools(&req.tools);
612                let messages = sanitize_messages_for_anthropic(&req.messages);
613                let messages = if messages.is_empty() {
614                    serde_json::json!([{"role": "user", "content": req.prompt}])
615                } else {
616                    serde_json::to_value(messages).unwrap_or_else(
617                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
618                    )
619                };
620                let mut body = serde_json::json!({
621                    "model": req.model,
622                    "max_tokens": 4096,
623                    "system": req.system_prompt_body,
624                    "messages": messages,
625                });
626                if !tools.is_null() {
627                    body["tools"] = tools;
628                }
629
630                let mut last_err: Option<ProviderError> = None;
631
632                for attempt in 1..=retry.max_attempts {
633                    let send_fut = client
634                        .post(&endpoint)
635                        .header("x-api-key", &api_key)
636                        .header("anthropic-version", &anthropic_version)
637                        .header("content-type", "application/json")
638                        .json(&body)
639                        .send();
640
641                    let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
642                        Ok(Ok(r)) => r,
643                        Ok(Err(e)) => {
644                            last_err = Some(ProviderError::Network {
645                                message: e.to_string(),
646                            });
647                            if attempt < retry.max_attempts {
648                                let backoff_ms = (retry.base_delay.as_millis() as u64)
649                                    .saturating_mul(1u64 << (attempt - 1));
650                                tokio::time::sleep(Duration::from_millis(
651                                    backoff_ms.min(retry.max_delay.as_millis() as u64),
652                                ))
653                                .await;
654                                continue;
655                            }
656                            break;
657                        }
658                        Err(_) => {
659                            last_err = Some(ProviderError::Timeout {
660                                elapsed_ms: timeout.per_request.as_millis() as u64,
661                            });
662                            if attempt < retry.max_attempts {
663                                tokio::time::sleep(retry.base_delay).await;
664                                continue;
665                            }
666                            break;
667                        }
668                    };
669
670                    let status = resp.status().as_u16();
671                    if !resp.status().is_success() {
672                        let text = resp.text().await.unwrap_or_default();
673                        let err = ProviderError::Api {
674                            status,
675                            message: text,
676                        };
677                        last_err = Some(err);
678
679                        // Retry 5xx; fail fast otherwise.
680                        let retryable = (500..=599).contains(&status);
681                        if retryable && attempt < retry.max_attempts {
682                            tokio::time::sleep(retry.base_delay).await;
683                            continue;
684                        }
685                        break;
686                    }
687
688                    let json: serde_json::Value =
689                        resp.json().await.map_err(|e| ProviderError::Parse {
690                            message: e.to_string(),
691                        })?;
692
693                    let response = json["content"]
694                        .as_array()
695                        .map(|arr| {
696                            arr.iter()
697                                .filter_map(|block| block["text"].as_str())
698                                .collect::<Vec<_>>()
699                                .join("")
700                        })
701                        .unwrap_or_default();
702
703                    return Ok(CompletionResponse {
704                        system_prompt: req.system_prompt_name.clone(),
705                        response,
706                        tool_count: req.tools.len(),
707                        skill_count: req.skill_count,
708                        mcp_server_count: req.mcp_servers.len(),
709                    });
710                }
711
712                Err(match last_err {
713                    Some(e) => ProviderError::RetryExhausted {
714                        attempts: retry.max_attempts,
715                        last_error: e.to_string(),
716                    },
717                    None => ProviderError::Other {
718                        message: "request failed".to_string(),
719                    },
720                })
721            })
722        }
723
724        fn stream(&self, request: &CompletionRequest) -> EventStream {
725            let req = request.clone();
726            let api_key = self.api_key.clone();
727            let client = self.client.clone();
728            let endpoint = self.endpoint.clone();
729            let anthropic_version = self.anthropic_version.clone();
730            let timeout = self.timeout.clone();
731
732            let (tx, rx) = mpsc::channel::<Result<ApiEvent, ProviderError>>(64);
733
734            tokio::spawn(async move {
735                let tools = build_anthropic_tools(&req.tools);
736                let messages = sanitize_messages_for_anthropic(&req.messages);
737                let messages = if messages.is_empty() {
738                    serde_json::json!([{"role": "user", "content": req.prompt}])
739                } else {
740                    serde_json::to_value(messages).unwrap_or_else(
741                        |_| serde_json::json!([{"role": "user", "content": req.prompt}]),
742                    )
743                };
744                let mut body = serde_json::json!({
745                    "model": req.model,
746                    "max_tokens": 4096,
747                    "system": req.system_prompt_body,
748                    "messages": messages,
749                    "stream": true,
750                });
751                if !tools.is_null() {
752                    body["tools"] = tools;
753                }
754
755                let send_fut = client
756                    .post(&endpoint)
757                    .header("x-api-key", &api_key)
758                    .header("anthropic-version", &anthropic_version)
759                    .header("content-type", "application/json")
760                    .json(&body)
761                    .send();
762
763                let resp = match tokio::time::timeout(timeout.per_request, send_fut).await {
764                    Ok(Ok(r)) => r,
765                    Ok(Err(e)) => {
766                        let _ = tx
767                            .send(Err(ProviderError::Network {
768                                message: e.to_string(),
769                            }))
770                            .await;
771                        return;
772                    }
773                    Err(_) => {
774                        let _ = tx
775                            .send(Err(ProviderError::Timeout {
776                                elapsed_ms: timeout.per_request.as_millis() as u64,
777                            }))
778                            .await;
779                        return;
780                    }
781                };
782
783                if !resp.status().is_success() {
784                    let status = resp.status().as_u16();
785                    let text = resp.text().await.unwrap_or_default();
786                    let _ = tx
787                        .send(Err(ProviderError::Api {
788                            status,
789                            message: text,
790                        }))
791                        .await;
792                    return;
793                }
794
795                let mut buf = String::new();
796                let mut bytes = resp.bytes_stream();
797                use futures_util::StreamExt;
798
799                let mut current_tool_use: Option<(String, String, String)> = None;
800
801                while let Some(chunk) = bytes.next().await {
802                    let chunk = match chunk {
803                        Ok(c) => c,
804                        Err(e) => {
805                            let _ = tx
806                                .send(Err(ProviderError::Network {
807                                    message: e.to_string(),
808                                }))
809                                .await;
810                            return;
811                        }
812                    };
813
814                    buf.push_str(&String::from_utf8_lossy(&chunk));
815
816                    // SSE frames are separated by a blank line.
817                    while let Some(idx) = buf.find("\n\n") {
818                        let frame: String = buf.drain(..(idx + 2)).collect();
819                        let mut data_lines = Vec::new();
820                        for line in frame.lines() {
821                            let line = line.trim();
822                            if let Some(rest) = line.strip_prefix("data:") {
823                                let payload = rest.trim();
824                                if !payload.is_empty() {
825                                    data_lines.push(payload.to_string());
826                                }
827                            }
828                        }
829
830                        if data_lines.is_empty() {
831                            continue;
832                        }
833
834                        let data = data_lines.join("\n");
835                        if data == "[DONE]" {
836                            continue;
837                        }
838
839                        let Ok(event) = serde_json::from_str::<serde_json::Value>(&data) else {
840                            continue;
841                        };
842                        let typ = event["type"].as_str().unwrap_or("");
843                        match typ {
844                            "content_block_start" => {
845                                let cb_type = event["content_block"]["type"].as_str().unwrap_or("");
846                                if cb_type == "text" {
847                                    if let Some(t) = event["content_block"]["text"].as_str() {
848                                        if !t.is_empty() {
849                                            let _ = tx
850                                                .send(Ok(ApiEvent::MessageDelta {
851                                                    text: t.to_string(),
852                                                }))
853                                                .await;
854                                        }
855                                    }
856                                } else if cb_type == "tool_use" {
857                                    let id = event["content_block"]["id"]
858                                        .as_str()
859                                        .unwrap_or("")
860                                        .to_string();
861                                    let name = event["content_block"]["name"]
862                                        .as_str()
863                                        .unwrap_or("")
864                                        .to_string();
865                                    let input = event["content_block"]["input"].clone();
866                                    let input_str = if input.is_null() {
867                                        String::new()
868                                    } else {
869                                        serde_json::to_string(&input).unwrap_or_default()
870                                    };
871                                    current_tool_use = Some((id, name, input_str));
872                                } else if cb_type == "thinking" {
873                                    if let Some(t) = event["content_block"]["thinking"].as_str() {
874                                        if !t.is_empty() {
875                                            let _ = tx
876                                                .send(Ok(ApiEvent::ThinkingDelta {
877                                                    text: t.to_string(),
878                                                }))
879                                                .await;
880                                        }
881                                    }
882                                }
883                            }
884                            "content_block_delta" => {
885                                let delta_type = event["delta"]["type"].as_str().unwrap_or("");
886                                match delta_type {
887                                    "text_delta" => {
888                                        if let Some(t) = event["delta"]["text"].as_str() {
889                                            let _ = tx
890                                                .send(Ok(ApiEvent::MessageDelta {
891                                                    text: t.to_string(),
892                                                }))
893                                                .await;
894                                        }
895                                    }
896                                    "thinking_delta" => {
897                                        if let Some(t) = event["delta"]["thinking"].as_str() {
898                                            let _ = tx
899                                                .send(Ok(ApiEvent::ThinkingDelta {
900                                                    text: t.to_string(),
901                                                }))
902                                                .await;
903                                        }
904                                    }
905                                    "input_json_delta" => {
906                                        if let Some(partial) =
907                                            event["delta"]["partial_json"].as_str()
908                                        {
909                                            if let Some((_id, _name, input_buf)) =
910                                                current_tool_use.as_mut()
911                                            {
912                                                input_buf.push_str(partial);
913                                            }
914                                        }
915                                    }
916                                    _ => {}
917                                }
918                            }
919                            "content_block_stop" => {
920                                if let Some((id, name, input)) = current_tool_use.take() {
921                                    let input = if input.trim().is_empty() {
922                                        "{}".to_string()
923                                    } else {
924                                        input
925                                    };
926                                    let _ = tx
927                                        .send(Ok(ApiEvent::ToolUse {
928                                            tool_use: ToolUseEvent { id, name, input },
929                                        }))
930                                        .await;
931                                }
932                            }
933                            "message_delta" => {
934                                if let Some(usage) = event.get("usage") {
935                                    let _ = tx
936                                        .send(Ok(ApiEvent::Usage {
937                                            usage: UsageEvent {
938                                                input_tokens: usage["input_tokens"]
939                                                    .as_u64()
940                                                    .unwrap_or(0),
941                                                output_tokens: usage["output_tokens"]
942                                                    .as_u64()
943                                                    .unwrap_or(0),
944                                                cache_read_tokens: usage["cache_read_input_tokens"]
945                                                    .as_u64()
946                                                    .unwrap_or(0),
947                                                cache_write_tokens:
948                                                    usage["cache_creation_input_tokens"]
949                                                        .as_u64()
950                                                        .unwrap_or(0),
951                                            },
952                                        }))
953                                        .await;
954                                }
955                            }
956                            "message_stop" => {
957                                let _ = tx.send(Ok(ApiEvent::Completed)).await;
958                                return;
959                            }
960                            _ => {}
961                        }
962                    }
963                }
964
965                let _ = tx.send(Ok(ApiEvent::Completed)).await;
966            });
967
968            Box::pin(futures_util::stream::unfold(rx, |mut rx| async move {
969                rx.recv().await.map(|item| (item, rx))
970            }))
971        }
972    }
973}
974
975// --- Provider factory (picks provider from env) ---
976
977pub fn create_provider() -> BoxedProvider {
978    let provider_name = std::env::var("CLAWEDCODE_PROVIDER").unwrap_or_default();
979
980    match provider_name.as_str() {
981        #[cfg(feature = "anthropic")]
982        "anthropic" => {
983            if let Some(p) = anthropic_provider::AnthropicProvider::from_env() {
984                tracing::info!("Using Anthropic provider");
985                return Box::new(p);
986            }
987            tracing::warn!("ANTHROPIC_API_KEY not set, falling back to mock provider");
988        }
989        #[cfg(not(feature = "anthropic"))]
990        "anthropic" => {
991            tracing::warn!(
992                "Anthropic provider requested but 'anthropic' feature not enabled, falling back to mock"
993            );
994        }
995        _ => {}
996    }
997
998    tracing::info!("Using mock provider");
999    Box::new(MockProvider)
1000}
1001
1002// --- Helper: collect stream into response ---
1003
1004pub async fn collect_stream_to_response(
1005    stream: EventStream,
1006    request: &CompletionRequest,
1007) -> Result<CompletionResponse, ProviderError> {
1008    use futures_util::StreamExt;
1009    let mut text = String::new();
1010    let mut thinking = String::new();
1011
1012    let mut s = stream;
1013    while let Some(event) = s.next().await {
1014        match event? {
1015            ApiEvent::MessageDelta { text: t } => text.push_str(&t),
1016            ApiEvent::ThinkingDelta { text: t } => thinking.push_str(&t),
1017            ApiEvent::Usage { usage: _ } => {}
1018            ApiEvent::Completed => break,
1019            ApiEvent::ToolUse { .. } | ApiEvent::ToolResult { .. } => {}
1020        }
1021    }
1022
1023    Ok(CompletionResponse {
1024        system_prompt: request.system_prompt_name.clone(),
1025        response: text,
1026        tool_count: request.tools.len(),
1027        skill_count: request.skill_count,
1028        mcp_server_count: request.mcp_servers.len(),
1029    })
1030}
1031
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;

    /// Minimal request shared by the mock-provider tests.
    fn sample_request() -> CompletionRequest {
        CompletionRequest {
            model: "test-model".to_string(),
            prompt_pack: "default".to_string(),
            system_prompt_name: "default".to_string(),
            system_prompt_body: "You are helpful.".to_string(),
            prompt: "hello".to_string(),
            messages: vec![],
            tools: vec![],
            skill_count: 0,
            mcp_servers: BTreeMap::new(),
        }
    }

    /// Streams `request` through [`MockProvider`], keeping only the `Ok` events.
    async fn mock_events(request: &CompletionRequest) -> Vec<ApiEvent> {
        let provider = MockProvider;
        // Bound to a local so the stream (which may borrow `provider`) is
        // dropped before `provider` goes out of scope.
        let events: Vec<ApiEvent> = provider
            .stream(request)
            .filter_map(|e| async move { e.ok() })
            .collect()
            .await;
        events
    }

    #[tokio::test]
    async fn mock_stream_has_multiple_deltas() {
        let events = mock_events(&sample_request()).await;
        assert!(!events.is_empty());

        let delta_count = events
            .iter()
            .filter(|e| {
                matches!(
                    e,
                    ApiEvent::MessageDelta { .. } | ApiEvent::ThinkingDelta { .. }
                )
            })
            .count();
        assert!(
            delta_count >= 2,
            "expected at least two delta events, got {delta_count}"
        );

        assert!(matches!(events.last(), Some(ApiEvent::Completed)));
    }

    #[tokio::test]
    async fn mock_stream_event_ordering() {
        let events = mock_events(&sample_request()).await;

        // `first`/`get` instead of indexing: a short stream fails the
        // assertion instead of panicking with an index-out-of-bounds error.
        assert!(matches!(events.first(), Some(ApiEvent::ThinkingDelta { .. })));
        assert!(matches!(events.get(1), Some(ApiEvent::MessageDelta { .. })));

        let usage_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Usage { .. }))
            .expect("Usage event should exist");
        let completed_idx = events
            .iter()
            .position(|e| matches!(e, ApiEvent::Completed))
            .expect("Completed event should exist");
        assert!(usage_idx < completed_idx);
    }

    #[tokio::test]
    async fn mock_stream_concatenated_text_matches_complete_response() {
        let request = sample_request();
        let provider = MockProvider;

        let direct = provider.complete(&request).await.unwrap();
        let streamed: String = mock_events(&request)
            .await
            .into_iter()
            .filter_map(|e| match e {
                ApiEvent::MessageDelta { text } => Some(text),
                _ => None,
            })
            .collect();

        assert_eq!(streamed, direct.response);
    }

    #[test]
    fn usage_account_accumulates() {
        let mut account = UsageAccount::default();
        account.record(&UsageEvent {
            input_tokens: 100,
            output_tokens: 50,
            cache_read_tokens: 10,
            cache_write_tokens: 20,
        });
        account.record(&UsageEvent {
            input_tokens: 200,
            output_tokens: 75,
            cache_read_tokens: 0,
            cache_write_tokens: 0,
        });

        assert_eq!(account.total_input_tokens, 300);
        assert_eq!(account.total_output_tokens, 125);
        assert_eq!(account.total_cache_read_tokens, 10);
        assert_eq!(account.total_cache_write_tokens, 20);
        assert_eq!(account.request_count, 2);
    }

    #[test]
    fn provider_error_display() {
        let err = ProviderError::Timeout { elapsed_ms: 5000 };
        assert!(err.to_string().contains("5000"));

        let err = ProviderError::RetryExhausted {
            attempts: 3,
            last_error: "timeout".to_string(),
        };
        assert!(err.to_string().contains("3"));
    }

    #[cfg(feature = "anthropic")]
    mod anthropic_tests {
        use crate::anthropic_provider::normalize_anthropic_endpoint;

        #[test]
        fn test_normalize_anthropic_endpoint_base_url() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_trailing_slash() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_already_has_v1() {
            assert_eq!(
                normalize_anthropic_endpoint("http://localhost:11434/v1/messages"),
                "http://localhost:11434/v1/messages"
            );
        }

        #[test]
        fn test_normalize_anthropic_endpoint_https_host() {
            assert_eq!(
                normalize_anthropic_endpoint("https://api.anthropic.com"),
                "https://api.anthropic.com/v1/messages"
            );
        }
    }
}