Skip to main content

nenjo_tool_api/
async_ops.rs

1//! Shared contracts for model-visible async operation tools.
2//!
3//! Runtime crates own operation scheduling, cancellation, polling, and event
4//! delivery. This module only defines stable tool names, argument DTOs, JSON
5//! schemas, and operation lifecycle enums.
6
7use serde::{Deserialize, Deserializer, Serialize};
8use serde_json::json;
9
/// Model-visible tool name paired with [`WaitOperationsArgs`] and
/// [`wait_operations_parameters_schema`].
pub const WAIT_OPERATIONS_TOOL_NAME: &str = "wait_operations";
/// Model-visible tool name paired with [`InspectOperationsArgs`] and
/// [`inspect_operations_parameters_schema`].
pub const INSPECT_OPERATIONS_TOOL_NAME: &str = "inspect_operations";
/// Model-visible tool name paired with [`StopOperationsArgs`] and
/// [`stop_operations_parameters_schema`].
pub const STOP_OPERATIONS_TOOL_NAME: &str = "stop_operations";
/// Model-visible tool name paired with [`SendOperationInputArgs`] and
/// [`send_operation_input_parameters_schema`].
pub const SEND_OPERATION_INPUT_TOOL_NAME: &str = "send_operation_input";
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum AsyncOperationKind {
18    Ability,
19    Delegation,
20    SubAgent,
21    Shell,
22    Media,
23}
24
25impl AsyncOperationKind {
26    pub fn as_str(self) -> &'static str {
27        match self {
28            Self::Ability => "ability",
29            Self::Delegation => "delegation",
30            Self::SubAgent => "sub_agent",
31            Self::Shell => "shell",
32            Self::Media => "media",
33        }
34    }
35}
36
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum AsyncOperationStatus {
40    Running,
41    WaitingForInput,
42    Completed,
43    Failed,
44    Stopped,
45}
46
47impl AsyncOperationStatus {
48    pub fn as_str(self) -> &'static str {
49        match self {
50            Self::Running => "running",
51            Self::WaitingForInput => "waiting_for_input",
52            Self::Completed => "completed",
53            Self::Failed => "failed",
54            Self::Stopped => "stopped",
55        }
56    }
57
58    pub fn can_receive_input(self) -> bool {
59        matches!(self, Self::Running | Self::WaitingForInput)
60    }
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
64#[serde(rename_all = "snake_case")]
65pub enum AsyncOperationSignalKind {
66    Started,
67    Progress,
68    NeedsInput,
69    Completed,
70    Failed,
71    Stopped,
72}
73
74impl AsyncOperationSignalKind {
75    pub fn as_str(self) -> &'static str {
76        match self {
77            Self::Started => "started",
78            Self::Progress => "progress",
79            Self::NeedsInput => "needs_input",
80            Self::Completed => "completed",
81            Self::Failed => "failed",
82            Self::Stopped => "stopped",
83        }
84    }
85}
86
87#[derive(Debug, Clone, Deserialize)]
88pub struct InspectOperationsArgs {
89    #[serde(default)]
90    pub operations: Vec<String>,
91    #[serde(default)]
92    pub kind: Option<AsyncOperationKind>,
93    #[serde(default)]
94    pub include_transcript: bool,
95    #[serde(
96        default = "default_inspect_limit",
97        deserialize_with = "deserialize_usize_from_json_number"
98    )]
99    pub limit: usize,
100}
101
102#[derive(Debug, Clone, Deserialize)]
103pub struct StopOperationsArgs {
104    #[serde(default)]
105    pub operations: Vec<String>,
106    #[serde(default)]
107    pub kind: Option<AsyncOperationKind>,
108    pub reason: Option<String>,
109}
110
111#[derive(Debug, Clone, Deserialize)]
112pub struct WaitOperationsArgs {
113    #[serde(default = "default_wait_seconds")]
114    pub seconds: u64,
115    #[serde(default)]
116    pub kind: Option<AsyncOperationKind>,
117    pub reason: Option<String>,
118}
119
120#[derive(Debug, Clone, Deserialize)]
121pub struct SendOperationInputArgs {
122    #[serde(default)]
123    pub operations: Vec<String>,
124    pub message: String,
125}
126
127pub fn inspect_operations_parameters_schema() -> serde_json::Value {
128    json!({
129        "type": "object",
130        "properties": {
131            "operations": {"type": "array", "items": {"type": "string"}},
132            "kind": operation_kind_schema(),
133            "include_transcript": {"type": "boolean"},
134            "limit": {"type": "integer", "minimum": 1, "maximum": 50}
135        },
136        "additionalProperties": false
137    })
138}
139
140pub fn stop_operations_parameters_schema() -> serde_json::Value {
141    json!({
142        "type": "object",
143        "properties": {
144            "operations": {"type": "array", "items": {"type": "string"}},
145            "kind": operation_kind_schema(),
146            "reason": {"type": "string"}
147        },
148        "additionalProperties": false
149    })
150}
151
152pub fn wait_operations_parameters_schema() -> serde_json::Value {
153    json!({
154        "type": "object",
155        "properties": {
156            "seconds": {"type": "number", "minimum": 1, "maximum": 30},
157            "kind": operation_kind_schema(),
158            "reason": {"type": "string"}
159        },
160        "additionalProperties": false
161    })
162}
163
164pub fn send_operation_input_parameters_schema() -> serde_json::Value {
165    json!({
166        "type": "object",
167        "properties": {
168            "operations": {"type": "array", "items": {"type": "string"}},
169            "message": {"type": "string"}
170        },
171        "required": ["operations", "message"],
172        "additionalProperties": false
173    })
174}
175
176pub fn operation_kind_schema() -> serde_json::Value {
177    json!({
178        "type": "string",
179        "enum": ["ability", "delegation", "sub_agent", "shell", "media"]
180    })
181}
182
/// Serde default for `InspectOperationsArgs::limit` when the field is omitted.
fn default_inspect_limit() -> usize {
    const DEFAULT_LIMIT: usize = 30;
    DEFAULT_LIMIT
}
186
187pub fn deserialize_usize_from_json_number<'de, D>(deserializer: D) -> Result<usize, D::Error>
188where
189    D: Deserializer<'de>,
190{
191    let value = serde_json::Value::deserialize(deserializer)?;
192    match value {
193        serde_json::Value::Number(number) => {
194            if let Some(raw) = number.as_u64() {
195                usize::try_from(raw).map_err(serde::de::Error::custom)
196            } else if let Some(raw) = number.as_f64() {
197                if raw.is_finite() && raw.fract() == 0.0 && raw >= 0.0 {
198                    usize::try_from(raw as u64).map_err(serde::de::Error::custom)
199                } else {
200                    Err(serde::de::Error::custom(
201                        "expected a non-negative whole number",
202                    ))
203                }
204            } else {
205                Err(serde::de::Error::custom(
206                    "expected a non-negative whole number",
207                ))
208            }
209        }
210        other => Err(serde::de::Error::custom(format!(
211            "expected a non-negative whole number, got {other}"
212        ))),
213    }
214}
215
/// Serde default for `WaitOperationsArgs::seconds` when the field is omitted.
fn default_wait_seconds() -> u64 {
    const DEFAULT_SECONDS: u64 = 10;
    DEFAULT_SECONDS
}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn async_operation_kind_uses_wire_names() {
        let cases = [
            (AsyncOperationKind::Ability, "ability"),
            (AsyncOperationKind::Delegation, "delegation"),
            (AsyncOperationKind::SubAgent, "sub_agent"),
            (AsyncOperationKind::Shell, "shell"),
            (AsyncOperationKind::Media, "media"),
        ];
        for &(kind, wire_name) in &cases {
            assert_eq!(kind.as_str(), wire_name, "wire name for {kind:?}");
        }
    }

    #[test]
    fn wait_args_deserialize_with_defaults() {
        let parsed: WaitOperationsArgs =
            serde_json::from_value(json!({})).expect("empty object should fall back to defaults");

        assert_eq!(parsed.seconds, 10);
        assert!(parsed.kind.is_none());
    }

    #[test]
    fn inspect_args_accept_whole_float_limit_from_model_args() {
        let payload = json!({
            "operations": ["ability_build_agent_2"],
            "include_transcript": true,
            "limit": 5.0
        });
        let parsed: InspectOperationsArgs =
            serde_json::from_value(payload).expect("whole-valued float limit should parse");

        assert_eq!(parsed.limit, 5);
    }

    #[test]
    fn inspect_args_reject_fractional_limit() {
        let outcome = serde_json::from_value::<InspectOperationsArgs>(json!({ "limit": 5.5 }));
        let message = outcome
            .expect_err("fractional limit must be rejected")
            .to_string();

        assert!(message.contains("whole number"), "unexpected error: {message}");
    }
}