1use std::time::Duration;
2
3use crate::plugin::{BlockReason, HookPhase};
4use serde_json::{Map, Value};
5
6use crate::runtime::core::Runtime;
7use crate::runtime::errors::RpcError;
8use crate::runtime::hooks::RuntimeHookConfig;
9use crate::runtime::rpc_contract::{methods, RpcValidationMode};
10use crate::runtime::turn_output::{parse_thread_id, parse_turn_id};
11
12use super::flow::{
13 apply_pre_hook_actions_to_session, result_status, HookContextInput, HookExecutionState,
14 SessionMutationState,
15};
16use super::wire::{
17 deserialize_result, input_item_to_wire, serialize_params, thread_overrides_to_wire,
18 turn_start_params_to_wire, validate_turn_start_security,
19};
20use super::*;
21
22impl ThreadHandle {
23 pub fn runtime(&self) -> &crate::runtime::core::Runtime {
24 &self.runtime
25 }
26
27 pub async fn turn_start(&self, p: TurnStartParams) -> Result<TurnHandle, RpcError> {
28 ensure_turn_input_not_empty(&p.input)?;
29 validate_turn_start_security(&p)?;
30
31 let response = self
32 .runtime
33 .call_validated(
34 methods::TURN_START,
35 turn_start_params_to_wire(&self.thread_id, &p),
36 )
37 .await?;
38
39 let turn_id = parse_turn_id(&response).ok_or_else(|| {
40 RpcError::InvalidRequest(format!("turn/start missing turn id in result: {response}"))
41 })?;
42
43 Ok(TurnHandle {
44 turn_id,
45 thread_id: self.thread_id.clone(),
46 })
47 }
48
49 pub async fn turn_steer(
53 &self,
54 expected_turn_id: &str,
55 input: Vec<InputItem>,
56 ) -> Result<super::TurnId, RpcError> {
57 ensure_turn_input_not_empty(&input)?;
58
59 let mut params = Map::<String, Value>::new();
60 params.insert("threadId".to_owned(), Value::String(self.thread_id.clone()));
61 params.insert(
62 "expectedTurnId".to_owned(),
63 Value::String(expected_turn_id.to_owned()),
64 );
65 params.insert(
66 "input".to_owned(),
67 Value::Array(input.iter().map(input_item_to_wire).collect()),
68 );
69 let response = self
70 .runtime
71 .call_validated(methods::TURN_START, Value::Object(params))
72 .await?;
73 parse_turn_id(&response).ok_or_else(|| {
74 RpcError::InvalidRequest(format!(
75 "turn/start(steer) missing turn id in result: {response}"
76 ))
77 })
78 }
79
80 pub async fn turn_interrupt(&self, turn_id: &str) -> Result<(), RpcError> {
81 self.runtime.turn_interrupt(&self.thread_id, turn_id).await
82 }
83}
84
85impl Runtime {
86 pub(crate) fn loaded_thread_handle(&self, thread_id: &str) -> ThreadHandle {
87 ThreadHandle {
88 thread_id: thread_id.to_owned(),
89 runtime: self.clone(),
90 }
91 }
92
93 pub async fn thread_start(&self, p: ThreadStartParams) -> Result<ThreadHandle, RpcError> {
94 self.thread_start_with_hooks(p, None).await
95 }
96
97 pub(crate) async fn thread_start_with_hooks(
98 &self,
99 p: ThreadStartParams,
100 scoped_hooks: Option<&RuntimeHookConfig>,
101 ) -> Result<ThreadHandle, RpcError> {
102 if !self.hooks_enabled_with(scoped_hooks) {
103 return self.thread_start_raw(p).await;
104 }
105
106 let (p, mut hook_state, start_cwd, start_model) = self
107 .prepare_session_start_hooks(p, None, scoped_hooks)
108 .await?;
109 let result = self.thread_start_raw(p).await;
110 self.finalize_session_start_hooks(
111 &mut hook_state,
112 start_cwd.as_deref(),
113 start_model.as_deref(),
114 None,
115 &result,
116 scoped_hooks,
117 )
118 .await;
119 self.publish_hook_report(hook_state.report);
120 result
121 }
122
123 pub async fn thread_resume(
124 &self,
125 thread_id: &str,
126 p: ThreadStartParams,
127 ) -> Result<ThreadHandle, RpcError> {
128 self.thread_resume_with_hooks(thread_id, p, None).await
129 }
130
131 pub(crate) async fn thread_resume_with_hooks(
132 &self,
133 thread_id: &str,
134 p: ThreadStartParams,
135 scoped_hooks: Option<&RuntimeHookConfig>,
136 ) -> Result<ThreadHandle, RpcError> {
137 if !self.hooks_enabled_with(scoped_hooks) {
138 return self.thread_resume_raw(thread_id, p).await;
139 }
140
141 let (p, mut hook_state, resume_cwd, resume_model) = self
142 .prepare_session_start_hooks(p, Some(thread_id), scoped_hooks)
143 .await?;
144 let result = self.thread_resume_raw(thread_id, p).await;
145 self.finalize_session_start_hooks(
146 &mut hook_state,
147 resume_cwd.as_deref(),
148 resume_model.as_deref(),
149 Some(thread_id),
150 &result,
151 scoped_hooks,
152 )
153 .await;
154 self.publish_hook_report(hook_state.report);
155 result
156 }
157
158 async fn prepare_session_start_hooks(
159 &self,
160 p: ThreadStartParams,
161 thread_id: Option<&str>,
162 scoped_hooks: Option<&RuntimeHookConfig>,
163 ) -> Result<
164 (
165 ThreadStartParams,
166 HookExecutionState,
167 Option<String>,
168 Option<String>,
169 ),
170 RpcError,
171 > {
172 let mut hook_state = HookExecutionState::new(self.next_hook_correlation_id());
173 let mut session_state =
174 SessionMutationState::from_thread_start(&p, hook_state.metadata.clone());
175 let decisions = self
176 .execute_pre_hook_phase(
177 &mut hook_state,
178 HookPhase::PreSessionStart,
179 p.cwd.as_deref(),
180 p.model.as_deref(),
181 thread_id,
182 None,
183 scoped_hooks,
184 )
185 .await
186 .map_err(block_reason_to_rpc_error)?;
187 apply_pre_hook_actions_to_session(
188 &mut session_state,
189 HookPhase::PreSessionStart,
190 decisions,
191 &mut hook_state.report,
192 );
193 hook_state.metadata = session_state.metadata.clone();
194
195 let mut p = p;
196 p.model = session_state.model;
197 let cwd = p.cwd.clone();
198 let model = p.model.clone();
199
200 Ok((p, hook_state, cwd, model))
201 }
202
203 async fn finalize_session_start_hooks(
204 &self,
205 hook_state: &mut HookExecutionState,
206 cwd: Option<&str>,
207 model: Option<&str>,
208 fallback_thread_id: Option<&str>,
209 result: &Result<ThreadHandle, RpcError>,
210 scoped_hooks: Option<&RuntimeHookConfig>,
211 ) {
212 let post_thread_id = result
213 .as_ref()
214 .ok()
215 .map(|thread| thread.thread_id.as_str())
216 .or(fallback_thread_id);
217 self.execute_post_hook_phase(
218 hook_state,
219 HookContextInput {
220 phase: HookPhase::PostSessionStart,
221 cwd,
222 model,
223 thread_id: post_thread_id,
224 turn_id: None,
225 main_status: Some(result_status(result)),
226 },
227 scoped_hooks,
228 )
229 .await;
230 }
231
232 pub(crate) async fn thread_resume_raw(
233 &self,
234 thread_id: &str,
235 p: ThreadStartParams,
236 ) -> Result<ThreadHandle, RpcError> {
237 let p = super::escalate_approval_if_tool_hooks(self, p);
238 super::wire::validate_thread_start_security(&p)?;
239 let mut params = Map::<String, Value>::new();
240 params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
241 let overrides = thread_overrides_to_wire(&p);
242 if !overrides.is_empty() {
243 params.insert("overrides".to_owned(), Value::Object(overrides));
244 }
245
246 let response = self
247 .call_validated(methods::THREAD_RESUME, Value::Object(params))
248 .await?;
249 let resumed = parse_thread_id(&response).ok_or_else(|| {
250 RpcError::InvalidRequest(format!(
251 "thread/resume missing thread id in result: {response}"
252 ))
253 })?;
254 if resumed != thread_id {
255 return Err(RpcError::InvalidRequest(format!(
256 "thread/resume returned mismatched thread id: requested={thread_id} actual={resumed}"
257 )));
258 }
259 Ok(ThreadHandle {
260 thread_id: resumed,
261 runtime: self.clone(),
262 })
263 }
264
265 pub async fn thread_fork(&self, thread_id: &str) -> Result<ThreadHandle, RpcError> {
266 let mut params = Map::<String, Value>::new();
267 params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
268 let response = self
269 .call_validated(methods::THREAD_FORK, Value::Object(params))
270 .await?;
271 let forked = parse_thread_id(&response).ok_or_else(|| {
272 RpcError::InvalidRequest(format!(
273 "thread/fork missing thread id in result: {response}"
274 ))
275 })?;
276 Ok(ThreadHandle {
277 thread_id: forked,
278 runtime: self.clone(),
279 })
280 }
281
282 pub async fn thread_archive(&self, thread_id: &str) -> Result<(), RpcError> {
286 let mut params = Map::<String, Value>::new();
287 params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
288 let _ = self
289 .call_validated(methods::THREAD_ARCHIVE, Value::Object(params))
290 .await?;
291 Ok(())
292 }
293
294 pub async fn thread_read(&self, p: ThreadReadParams) -> Result<ThreadReadResponse, RpcError> {
298 let params = serialize_params(methods::THREAD_READ, &p)?;
299 let response = self.call_validated(methods::THREAD_READ, params).await?;
300 deserialize_result(methods::THREAD_READ, response)
301 }
302
303 pub async fn thread_list(&self, p: ThreadListParams) -> Result<ThreadListResponse, RpcError> {
307 let params = serialize_params(methods::THREAD_LIST, &p)?;
308 let response = self.call_validated(methods::THREAD_LIST, params).await?;
309 deserialize_result(methods::THREAD_LIST, response)
310 }
311
312 pub async fn thread_loaded_list(
316 &self,
317 p: ThreadLoadedListParams,
318 ) -> Result<ThreadLoadedListResponse, RpcError> {
319 let params = serialize_params(methods::THREAD_LOADED_LIST, &p)?;
320 let response = self
321 .call_validated(methods::THREAD_LOADED_LIST, params)
322 .await?;
323 deserialize_result(methods::THREAD_LOADED_LIST, response)
324 }
325
326 pub async fn skills_list(&self, p: SkillsListParams) -> Result<SkillsListResponse, RpcError> {
330 let params = serialize_params(methods::SKILLS_LIST, &p)?;
331 let response = self.call_validated(methods::SKILLS_LIST, params).await?;
332 deserialize_result(methods::SKILLS_LIST, response)
333 }
334
335 pub async fn thread_rollback(
339 &self,
340 p: ThreadRollbackParams,
341 ) -> Result<ThreadRollbackResponse, RpcError> {
342 let params = serialize_params(methods::THREAD_ROLLBACK, &p)?;
343 let response = self
344 .call_validated(methods::THREAD_ROLLBACK, params)
345 .await?;
346 deserialize_result(methods::THREAD_ROLLBACK, response)
347 }
348
349 pub async fn turn_interrupt(&self, thread_id: &str, turn_id: &str) -> Result<(), RpcError> {
353 let _ = self
354 .call_validated(
355 methods::TURN_INTERRUPT,
356 interrupt_params(thread_id, turn_id),
357 )
358 .await?;
359 Ok(())
360 }
361
362 pub async fn turn_interrupt_with_timeout(
366 &self,
367 thread_id: &str,
368 turn_id: &str,
369 timeout_duration: Duration,
370 ) -> Result<(), RpcError> {
371 let _ = self
372 .call_validated_with_mode_and_timeout(
373 methods::TURN_INTERRUPT,
374 interrupt_params(thread_id, turn_id),
375 RpcValidationMode::KnownMethods,
376 timeout_duration,
377 )
378 .await?;
379 Ok(())
380 }
381}
382
383fn interrupt_params(thread_id: &str, turn_id: &str) -> Value {
384 let mut params = Map::<String, Value>::new();
385 params.insert("threadId".to_owned(), Value::String(thread_id.to_owned()));
386 params.insert("turnId".to_owned(), Value::String(turn_id.to_owned()));
387 Value::Object(params)
388}
389
390fn ensure_turn_input_not_empty(input: &[InputItem]) -> Result<(), RpcError> {
391 if input.is_empty() {
392 return Err(RpcError::InvalidRequest(
393 "turn input must not be empty".to_owned(),
394 ));
395 }
396 Ok(())
397}
398
399fn block_reason_to_rpc_error(r: BlockReason) -> RpcError {
402 RpcError::InvalidRequest(format!(
403 "blocked by hook '{}' at {:?}: {}",
404 r.hook_name, r.phase, r.message
405 ))
406}