Skip to main content

claude_wrapper/
types.rs

1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
6/// Transport type for MCP server connections.
7///
8/// # Example
9///
10/// ```
11/// use claude_wrapper::Transport;
12/// use std::str::FromStr;
13///
14/// let t = Transport::from_str("stdio").unwrap();
15/// assert_eq!(t, Transport::Stdio);
16/// assert_eq!(t.to_string(), "stdio");
17///
18/// let t: Transport = "http".parse().unwrap();
19/// assert_eq!(t, Transport::Http);
20/// ```
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum Transport {
24    /// Standard I/O transport — server runs as a subprocess.
25    Stdio,
26    /// HTTP transport — server accessible via URL.
27    Http,
28    /// Server-Sent Events transport — server accessible via URL with SSE.
29    Sse,
30}
31
32impl fmt::Display for Transport {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Self::Stdio => write!(f, "stdio"),
36            Self::Http => write!(f, "http"),
37            Self::Sse => write!(f, "sse"),
38        }
39    }
40}
41
42/// Error returned when parsing an unknown transport string.
43#[derive(Debug, Clone, thiserror::Error)]
44#[error("unknown transport: {0}")]
45pub struct TransportParseError(pub String);
46
47impl FromStr for Transport {
48    type Err = TransportParseError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "stdio" => Ok(Self::Stdio),
53            "http" => Ok(Self::Http),
54            "sse" => Ok(Self::Sse),
55            other => Err(TransportParseError(other.to_string())),
56        }
57    }
58}
59
60impl TryFrom<&str> for Transport {
61    type Error = TransportParseError;
62
63    fn try_from(s: &str) -> Result<Self, Self::Error> {
64        s.parse()
65    }
66}
67
#[cfg(test)]
mod transport_tests {
    use super::*;

    /// Each supported name parses to its matching variant.
    #[test]
    fn from_str_accepts_known_variants() {
        let cases = [
            ("stdio", Transport::Stdio),
            ("http", Transport::Http),
            ("sse", Transport::Sse),
        ];
        for (input, expected) in cases {
            let parsed: Transport = input.parse().unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// The error message for an unknown name echoes the rejected input.
    #[test]
    fn from_str_returns_error_for_unknown() {
        let message = "websocket".parse::<Transport>().unwrap_err().to_string();
        assert!(message.contains("websocket"));
    }

    /// `TryFrom<&str>` rejects unknown names as well.
    #[test]
    fn try_from_str_returns_error_for_unknown() {
        let outcome = Transport::try_from("bogus");
        assert!(outcome.is_err());
    }
}
90
/// Output format for `--output-format`.
///
/// Derives `PartialEq`/`Eq` like the other option enums in this file
/// (`Transport`, `PermissionMode`, `Effort`) so values can be compared and
/// used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OutputFormat {
    /// Plain text output (default).
    #[default]
    Text,
    /// Single JSON result object.
    Json,
    /// Streaming NDJSON.
    StreamJson,
}
102
103impl OutputFormat {
104    pub(crate) fn as_arg(&self) -> &'static str {
105        match self {
106            Self::Text => "text",
107            Self::Json => "json",
108            Self::StreamJson => "stream-json",
109        }
110    }
111}
112
113/// Permission mode for `--permission-mode`.
114#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "camelCase")]
116pub enum PermissionMode {
117    /// Default interactive permissions.
118    #[default]
119    Default,
120    /// Auto-accept file edits.
121    AcceptEdits,
122    /// Bypass all permission checks.
123    BypassPermissions,
124    /// Don't ask for permissions (deny by default).
125    DontAsk,
126    /// Plan mode (read-only).
127    Plan,
128    /// Auto mode.
129    Auto,
130}
131
132impl PermissionMode {
133    pub(crate) fn as_arg(&self) -> &'static str {
134        match self {
135            Self::Default => "default",
136            Self::AcceptEdits => "acceptEdits",
137            Self::BypassPermissions => "bypassPermissions",
138            Self::DontAsk => "dontAsk",
139            Self::Plan => "plan",
140            Self::Auto => "auto",
141        }
142    }
143}
144
/// Input format for `--input-format`.
///
/// Derives `PartialEq`/`Eq` like the other option enums in this file
/// (`Transport`, `PermissionMode`, `Effort`) so values can be compared and
/// used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InputFormat {
    /// Plain text input (default).
    #[default]
    Text,
    /// Streaming JSON input.
    StreamJson,
}
154
155impl InputFormat {
156    pub(crate) fn as_arg(&self) -> &'static str {
157        match self {
158            Self::Text => "text",
159            Self::StreamJson => "stream-json",
160        }
161    }
162}
163
164/// Effort level for `--effort`.
165#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
166#[serde(rename_all = "lowercase")]
167pub enum Effort {
168    /// Low effort.
169    Low,
170    /// Medium effort (default).
171    Medium,
172    /// High effort.
173    High,
174    /// Maximum effort, most thorough.
175    Max,
176}
177
178impl Effort {
179    pub(crate) fn as_arg(&self) -> &'static str {
180        match self {
181            Self::Low => "low",
182            Self::Medium => "medium",
183            Self::High => "high",
184            Self::Max => "max",
185        }
186    }
187}
188
/// Scope for MCP and plugin commands.
///
/// Derives `PartialEq`/`Eq` like the other option enums in this file
/// (`Transport`, `PermissionMode`, `Effort`) so values can be compared and
/// used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Scope {
    /// Local scope (current directory); this is the `Default` variant.
    #[default]
    Local,
    /// User scope (global).
    User,
    /// Project scope.
    Project,
}
200
201impl Scope {
202    pub(crate) fn as_arg(&self) -> &'static str {
203        match self {
204            Self::Local => "local",
205            Self::User => "user",
206            Self::Project => "project",
207        }
208    }
209}
210
/// Authentication status returned by `claude auth status --json`.
///
/// Field names are mapped to camelCase keys in the JSON (for example
/// `logged_in` is read from `loggedIn`). Every modeled field falls back to
/// its default when absent, and any key not modeled here is kept in
/// [`AuthStatus::extra`].
///
/// # Example
///
/// ```no_run
/// # async fn example() -> claude_wrapper::Result<()> {
/// let claude = claude_wrapper::Claude::builder().build()?;
/// let status = claude_wrapper::AuthStatusCommand::new()
///     .execute_json(&claude).await?;
///
/// if status.logged_in {
///     println!("Logged in as {}", status.email.unwrap_or_default());
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthStatus {
    /// Whether the user is currently logged in; `false` when the key is missing.
    #[serde(default)]
    pub logged_in: bool,
    /// Authentication method (e.g. "claude.ai").
    #[serde(default)]
    pub auth_method: Option<String>,
    /// API provider (e.g. "firstParty").
    #[serde(default)]
    pub api_provider: Option<String>,
    /// Email address of the authenticated user.
    #[serde(default)]
    pub email: Option<String>,
    /// Organization ID.
    #[serde(default)]
    pub org_id: Option<String>,
    /// Organization name.
    #[serde(default)]
    pub org_name: Option<String>,
    /// Subscription type (e.g. "pro", "max").
    #[serde(default)]
    pub subscription_type: Option<String>,
    /// Remaining top-level keys that have no dedicated field, stored as raw JSON.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
256
/// One message from a query result, i.e. a single turn in the conversation.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryMessage {
    /// Role of the sender (e.g., "user", "assistant"); empty when the key is
    /// missing.
    #[serde(default)]
    pub role: String,
    /// Message content as raw JSON. This is an arbitrary
    /// [`serde_json::Value`], not necessarily a string, and is JSON `null`
    /// when the key is missing.
    #[serde(default)]
    pub content: serde_json::Value,
    /// Remaining keys returned by the CLI that have no dedicated field.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
271
/// Result of a query run with `--output-format json`.
///
/// Every modeled field falls back to its default when absent, and any key
/// not modeled here is kept in [`QueryResult::extra`].
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryResult {
    /// Text of the query response.
    #[serde(default)]
    pub result: String,
    /// Session ID, usable for continuing the conversation.
    #[serde(default)]
    pub session_id: String,
    /// Total cost of the query in USD.
    ///
    /// Written as `total_cost_usd` when serializing; when deserializing,
    /// the key `cost_usd` is accepted as well.
    #[serde(default, rename = "total_cost_usd", alias = "cost_usd")]
    pub cost_usd: Option<f64>,
    /// Duration of the query in milliseconds.
    #[serde(default)]
    pub duration_ms: Option<u64>,
    /// Number of conversation turns in the query.
    #[serde(default)]
    pub num_turns: Option<u32>,
    /// Whether the query resulted in an error; `false` when the key is missing.
    #[serde(default)]
    pub is_error: bool,
    /// Remaining keys returned by the CLI that have no dedicated field.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
298
#[cfg(all(test, feature = "json"))]
mod tests {
    use super::*;

    /// Decodes `raw` into a [`QueryResult`], panicking if it does not parse.
    fn parse(raw: &str) -> QueryResult {
        serde_json::from_str(raw).unwrap()
    }

    /// The primary key `total_cost_usd` populates `cost_usd`.
    #[test]
    fn query_result_deserializes_total_cost_usd() {
        let parsed = parse(
            r#"{"result":"hello","session_id":"s1","total_cost_usd":0.042,"is_error":false}"#,
        );
        assert_eq!(parsed.cost_usd, Some(0.042));
    }

    /// The alias key `cost_usd` populates `cost_usd` too.
    #[test]
    fn query_result_deserializes_cost_usd_alias() {
        let parsed =
            parse(r#"{"result":"hello","session_id":"s1","cost_usd":0.01,"is_error":false}"#);
        assert_eq!(parsed.cost_usd, Some(0.01));
    }

    /// With neither cost key present the field is `None`.
    #[test]
    fn query_result_missing_cost_defaults_to_none() {
        let parsed = parse(r#"{"result":"hello","session_id":"s1","is_error":false}"#);
        assert_eq!(parsed.cost_usd, None);
    }

    /// Exact shape of a `--output-format stream-json` "result" event: the
    /// flattened `extra` map must absorb `type`/`subtype` without disturbing
    /// the typed fields.
    #[test]
    fn query_result_from_stream_result_event() {
        let parsed = parse(
            r#"{"type":"result","subtype":"success","result":"streamed","session_id":"sess-1","total_cost_usd":0.03,"num_turns":1,"is_error":false}"#,
        );
        assert_eq!(parsed.cost_usd, Some(0.03));
        assert_eq!(parsed.num_turns, Some(1));
        assert_eq!(parsed.session_id, "sess-1");
        assert_eq!(parsed.result, "streamed");

        let event_type = parsed.extra.get("type").and_then(serde_json::Value::as_str);
        assert_eq!(event_type, Some("result"));
    }

    /// `num_turns` is read alongside the cost.
    #[test]
    fn query_result_deserializes_num_turns() {
        let parsed = parse(
            r#"{"result":"done","session_id":"s2","total_cost_usd":0.1,"num_turns":5,"is_error":false}"#,
        );
        assert_eq!(parsed.num_turns, Some(5));
        assert_eq!(parsed.cost_usd, Some(0.1));
    }

    /// Serialization uses the renamed `total_cost_usd` key.
    #[test]
    fn query_result_serializes_as_total_cost_usd() {
        let original = QueryResult {
            result: String::from("ok"),
            session_id: String::from("s1"),
            cost_usd: Some(0.05),
            duration_ms: None,
            num_turns: Some(3),
            is_error: false,
            extra: std::collections::HashMap::new(),
        };
        let encoded = serde_json::to_string(&original).unwrap();
        for key in ["\"total_cost_usd\"", "\"num_turns\""] {
            assert!(encoded.contains(key));
        }
    }
}