Skip to main content

claude_wrapper/
types.rs

1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
6/// Transport type for MCP server connections.
7///
8/// # Example
9///
10/// ```
11/// use claude_wrapper::Transport;
12/// use std::str::FromStr;
13///
14/// let t = Transport::from_str("stdio").unwrap();
15/// assert_eq!(t, Transport::Stdio);
16/// assert_eq!(t.to_string(), "stdio");
17///
18/// let t: Transport = "http".parse().unwrap();
19/// assert_eq!(t, Transport::Http);
20/// ```
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum Transport {
24    /// Standard I/O transport — server runs as a subprocess.
25    Stdio,
26    /// HTTP transport — server accessible via URL.
27    Http,
28    /// Server-Sent Events transport — server accessible via URL with SSE.
29    Sse,
30}
31
32impl fmt::Display for Transport {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Self::Stdio => write!(f, "stdio"),
36            Self::Http => write!(f, "http"),
37            Self::Sse => write!(f, "sse"),
38        }
39    }
40}
41
42/// Error returned when parsing an unknown transport string.
43#[derive(Debug, Clone, thiserror::Error)]
44#[error("unknown transport: {0}")]
45pub struct TransportParseError(pub String);
46
47impl FromStr for Transport {
48    type Err = TransportParseError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        match s {
52            "stdio" => Ok(Self::Stdio),
53            "http" => Ok(Self::Http),
54            "sse" => Ok(Self::Sse),
55            other => Err(TransportParseError(other.to_string())),
56        }
57    }
58}
59
60impl From<&str> for Transport {
61    /// Convert a string slice to a `Transport`.
62    ///
63    /// # Panics
64    ///
65    /// Panics if the string is not a recognized transport type.
66    /// Valid values are: `"stdio"`, `"http"`, `"sse"`.
67    fn from(s: &str) -> Self {
68        s.parse().unwrap_or_else(|e| panic!("{e}"))
69    }
70}
71
/// Output format for `--output-format`.
///
/// The default is [`OutputFormat::Text`]. `PartialEq`/`Eq` are derived so
/// values can be compared directly, like the other option enums in this
/// module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OutputFormat {
    /// Plain text output (default).
    #[default]
    Text,
    /// Single JSON result object.
    Json,
    /// Streaming NDJSON.
    StreamJson,
}

impl OutputFormat {
    /// Returns the string spelling of this format: `"text"`, `"json"` or
    /// `"stream-json"`.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Json => "json",
            Self::StreamJson => "stream-json",
        }
    }
}
93
94/// Permission mode for `--permission-mode`.
95#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
96#[serde(rename_all = "camelCase")]
97pub enum PermissionMode {
98    /// Default interactive permissions.
99    #[default]
100    Default,
101    /// Auto-accept file edits.
102    AcceptEdits,
103    /// Bypass all permission checks.
104    BypassPermissions,
105    /// Don't ask for permissions (deny by default).
106    DontAsk,
107    /// Plan mode (read-only).
108    Plan,
109    /// Auto mode.
110    Auto,
111}
112
113impl PermissionMode {
114    pub(crate) fn as_arg(&self) -> &'static str {
115        match self {
116            Self::Default => "default",
117            Self::AcceptEdits => "acceptEdits",
118            Self::BypassPermissions => "bypassPermissions",
119            Self::DontAsk => "dontAsk",
120            Self::Plan => "plan",
121            Self::Auto => "auto",
122        }
123    }
124}
125
/// Input format for `--input-format`.
///
/// The default is [`InputFormat::Text`]. `PartialEq`/`Eq` are derived so
/// values can be compared directly, like the other option enums in this
/// module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InputFormat {
    /// Plain text input (default).
    #[default]
    Text,
    /// Streaming JSON input.
    StreamJson,
}

impl InputFormat {
    /// Returns the string spelling of this format: `"text"` or
    /// `"stream-json"`.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::StreamJson => "stream-json",
        }
    }
}
144
145/// Effort level for `--effort`.
146#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
147#[serde(rename_all = "lowercase")]
148pub enum Effort {
149    /// Low effort.
150    Low,
151    /// Medium effort (default).
152    Medium,
153    /// High effort.
154    High,
155    /// Maximum effort, most thorough.
156    Max,
157}
158
159impl Effort {
160    pub(crate) fn as_arg(&self) -> &'static str {
161        match self {
162            Self::Low => "low",
163            Self::Medium => "medium",
164            Self::High => "high",
165            Self::Max => "max",
166        }
167    }
168}
169
/// Scope for MCP and plugin commands.
///
/// The default is [`Scope::Local`]. `PartialEq`/`Eq` are derived so values
/// can be compared directly, like the other option enums in this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Scope {
    /// Local scope (current directory).
    #[default]
    Local,
    /// User scope (global).
    User,
    /// Project scope.
    Project,
}

impl Scope {
    /// Returns the string spelling of this scope: `"local"`, `"user"` or
    /// `"project"`.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Local => "local",
            Self::User => "user",
            Self::Project => "project",
        }
    }
}
191
/// Authentication status as reported by `claude auth status --json`.
///
/// JSON keys are camelCase (`loggedIn`, `authMethod`, ...). Every modeled
/// field is optional on input: a missing key deserializes to `false` or
/// `None`. Keys that are not modeled here are collected in [`AuthStatus::extra`].
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> claude_wrapper::Result<()> {
/// let claude = claude_wrapper::Claude::builder().build()?;
/// let status = claude_wrapper::AuthStatusCommand::new()
///     .execute_json(&claude).await?;
///
/// if status.logged_in {
///     println!("Logged in as {}", status.email.unwrap_or_default());
/// }
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthStatus {
    /// `true` when a user is currently logged in.
    #[serde(default)]
    pub logged_in: bool,
    /// How the user authenticated (e.g. "claude.ai").
    #[serde(default)]
    pub auth_method: Option<String>,
    /// Which API provider is in use (e.g. "firstParty").
    #[serde(default)]
    pub api_provider: Option<String>,
    /// Email address of the authenticated user.
    #[serde(default)]
    pub email: Option<String>,
    /// Identifier of the user's organization.
    #[serde(default)]
    pub org_id: Option<String>,
    /// Display name of the user's organization.
    #[serde(default)]
    pub org_name: Option<String>,
    /// Subscription tier (e.g. "pro", "max").
    #[serde(default)]
    pub subscription_type: Option<String>,
    /// Catch-all for keys in the JSON object that have no typed field above.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
237
/// One turn of a conversation, as found in a query result.
///
/// Both typed fields are optional on input and fall back to their default
/// (empty string / JSON `null`) when the key is absent.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryMessage {
    /// Who sent the message (e.g., "user", "assistant").
    #[serde(default)]
    pub role: String,
    /// The message content, kept as a raw JSON value rather than a plain
    /// string.
    #[serde(default)]
    pub content: serde_json::Value,
    /// Catch-all for keys returned by the CLI that have no typed field above.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
252
/// Outcome of a query run with `--output-format json`.
///
/// Every typed field is optional on input and falls back to its default when
/// the key is absent. Keys that are not modeled here are collected in
/// [`QueryResult::extra`].
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryResult {
    /// Text of the query response.
    #[serde(default)]
    pub result: String,
    /// Session ID, used to continue the conversation.
    #[serde(default)]
    pub session_id: String,
    /// Total cost of the query in USD.
    ///
    /// Written as `total_cost_usd` when serializing; on input both
    /// `total_cost_usd` and the `cost_usd` alias are accepted.
    #[serde(default, rename = "total_cost_usd", alias = "cost_usd")]
    pub cost_usd: Option<f64>,
    /// How long the query took, in milliseconds.
    #[serde(default)]
    pub duration_ms: Option<u64>,
    /// How many conversation turns the query used.
    #[serde(default)]
    pub num_turns: Option<u32>,
    /// `true` when the query ended in an error.
    #[serde(default)]
    pub is_error: bool,
    /// Catch-all for keys returned by the CLI that have no typed field above.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
279
#[cfg(all(test, feature = "json"))]
mod tests {
    use super::*;

    /// Deserializes a `QueryResult` from a JSON literal, panicking if the
    /// input does not parse.
    fn parse(json: &str) -> QueryResult {
        serde_json::from_str(json).unwrap()
    }

    // The primary key name `total_cost_usd` populates `cost_usd`.
    #[test]
    fn query_result_deserializes_total_cost_usd() {
        let parsed =
            parse(r#"{"result":"hello","session_id":"s1","total_cost_usd":0.042,"is_error":false}"#);
        assert_eq!(parsed.cost_usd, Some(0.042));
    }

    // The `cost_usd` alias is accepted as well.
    #[test]
    fn query_result_deserializes_cost_usd_alias() {
        let parsed =
            parse(r#"{"result":"hello","session_id":"s1","cost_usd":0.01,"is_error":false}"#);
        assert_eq!(parsed.cost_usd, Some(0.01));
    }

    // Without any cost key the field falls back to `None`.
    #[test]
    fn query_result_missing_cost_defaults_to_none() {
        let parsed = parse(r#"{"result":"hello","session_id":"s1","is_error":false}"#);
        assert_eq!(parsed.cost_usd, None);
    }

    // `num_turns` is read alongside the cost.
    #[test]
    fn query_result_deserializes_num_turns() {
        let parsed = parse(
            r#"{"result":"done","session_id":"s2","total_cost_usd":0.1,"num_turns":5,"is_error":false}"#,
        );
        assert_eq!(parsed.num_turns, Some(5));
        assert_eq!(parsed.cost_usd, Some(0.1));
    }

    // Serialization uses the renamed key, not the Rust field name.
    #[test]
    fn query_result_serializes_as_total_cost_usd() {
        let value = QueryResult {
            result: "ok".into(),
            session_id: "s1".into(),
            cost_usd: Some(0.05),
            duration_ms: None,
            num_turns: Some(3),
            is_error: false,
            extra: Default::default(),
        };
        let encoded = serde_json::to_string(&value).unwrap();
        for key in ["\"total_cost_usd\"", "\"num_turns\""] {
            assert!(encoded.contains(key));
        }
    }
}