llm_worker/
hook.rs

1//! Hook関連の型定義
2//!
3//! Worker層でのターン制御・介入に使用される型
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use thiserror::Error;
9
10// =============================================================================
11// Hook Event Kinds
12// =============================================================================
13
14pub trait HookEventKind: Send + Sync + 'static {
15    type Input;
16    type Output;
17}
18
19pub struct OnPromptSubmit;
20pub struct PreLlmRequest;
21pub struct PreToolCall;
22pub struct PostToolCall;
23pub struct OnTurnEnd;
24pub struct OnAbort;
25
26#[derive(Debug, Clone, PartialEq, Eq)]
27pub enum OnPromptSubmitResult {
28    Continue,
29    Cancel(String),
30}
31
32#[derive(Debug, Clone, PartialEq, Eq)]
33pub enum PreLlmRequestResult {
34    Continue,
35    Cancel(String),
36}
37
38#[derive(Debug, Clone, PartialEq, Eq)]
39pub enum PreToolCallResult {
40    Continue,
41    Skip,
42    Abort(String),
43    Pause,
44}
45
46#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum PostToolCallResult {
48    Continue,
49    Abort(String),
50}
51
52#[derive(Debug, Clone)]
53pub enum OnTurnEndResult {
54    Finish,
55    ContinueWithMessages(Vec<crate::Message>),
56    Paused,
57}
58
59use std::sync::Arc;
60
61use crate::tool::{Tool, ToolMeta};
62
63/// PreToolCall の入力コンテキスト
64pub struct ToolCallContext {
65    /// ツール呼び出し情報(改変可能)
66    pub call: ToolCall,
67    /// ツールメタ情報(不変)
68    pub meta: ToolMeta,
69    /// ツールインスタンス(状態アクセス用)
70    pub tool: Arc<dyn Tool>,
71}
72
73/// PostToolCall の入力コンテキスト
74pub struct PostToolCallContext {
75    /// ツール呼び出し情報
76    pub call: ToolCall,
77    /// ツール実行結果(改変可能)
78    pub result: ToolResult,
79    /// ツールメタ情報(不変)
80    pub meta: ToolMeta,
81    /// ツールインスタンス(状態アクセス用)
82    pub tool: Arc<dyn Tool>,
83}
84
85impl HookEventKind for OnPromptSubmit {
86    type Input = crate::Message;
87    type Output = OnPromptSubmitResult;
88}
89
90impl HookEventKind for PreLlmRequest {
91    type Input = Vec<crate::Message>;
92    type Output = PreLlmRequestResult;
93}
94
95impl HookEventKind for PreToolCall {
96    type Input = ToolCallContext;
97    type Output = PreToolCallResult;
98}
99
100impl HookEventKind for PostToolCall {
101    type Input = PostToolCallContext;
102    type Output = PostToolCallResult;
103}
104
105impl HookEventKind for OnTurnEnd {
106    type Input = Vec<crate::Message>;
107    type Output = OnTurnEndResult;
108}
109
110impl HookEventKind for OnAbort {
111    type Input = String;
112    type Output = ();
113}
114
115// =============================================================================
116// Tool Call / Result Types
117// =============================================================================
118
119/// ツール呼び出し情報
120///
121/// LLMからのToolUseブロックを表現し、Hook処理で改変可能
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct ToolCall {
124    /// ツール呼び出しID(レスポンスとの紐付けに使用)
125    pub id: String,
126    /// ツール名
127    pub name: String,
128    /// 入力引数(JSON)
129    pub input: Value,
130}
131
132/// ツール実行結果
133///
134/// ツール実行後の結果を表現し、Hook処理で改変可能
135#[derive(Debug, Clone, Serialize, Deserialize)]
136pub struct ToolResult {
137    /// 対応するツール呼び出しID
138    pub tool_use_id: String,
139    /// 結果コンテンツ
140    pub content: String,
141    /// エラーかどうか
142    #[serde(default)]
143    pub is_error: bool,
144}
145
146impl ToolResult {
147    /// 成功結果を作成
148    pub fn success(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
149        Self {
150            tool_use_id: tool_use_id.into(),
151            content: content.into(),
152            is_error: false,
153        }
154    }
155
156    /// エラー結果を作成
157    pub fn error(tool_use_id: impl Into<String>, content: impl Into<String>) -> Self {
158        Self {
159            tool_use_id: tool_use_id.into(),
160            content: content.into(),
161            is_error: true,
162        }
163    }
164}
165
166// =============================================================================
167// Hook Error
168// =============================================================================
169
170/// Hookエラー
171#[derive(Debug, Error)]
172pub enum HookError {
173    /// 処理が中断された
174    #[error("Aborted: {0}")]
175    Aborted(String),
176    /// 内部エラー
177    #[error("Hook error: {0}")]
178    Internal(String),
179}
180
181// =============================================================================
182// Hook Trait
183// =============================================================================
184
185/// Hookイベントの処理を行うトレイト
186///
187/// 各イベント種別は戻り値型が異なるため、`HookEventKind`を介して型を制約する。
188#[async_trait]
189pub trait Hook<E: HookEventKind>: Send + Sync {
190    async fn call(&self, input: &mut E::Input) -> Result<E::Output, HookError>;
191}
192
193// =============================================================================
194// Hook Registry
195// =============================================================================
196
197/// 全 Hook を保持するレジストリ
198///
199/// Worker 内部で使用され、各種 Hook を一括管理する。
200pub struct HookRegistry {
201    /// on_prompt_submit Hook
202    pub(crate) on_prompt_submit: Vec<Box<dyn Hook<OnPromptSubmit>>>,
203    /// pre_llm_request Hook
204    pub(crate) pre_llm_request: Vec<Box<dyn Hook<PreLlmRequest>>>,
205    /// pre_tool_call Hook
206    pub(crate) pre_tool_call: Vec<Box<dyn Hook<PreToolCall>>>,
207    /// post_tool_call Hook
208    pub(crate) post_tool_call: Vec<Box<dyn Hook<PostToolCall>>>,
209    /// on_turn_end Hook
210    pub(crate) on_turn_end: Vec<Box<dyn Hook<OnTurnEnd>>>,
211    /// on_abort Hook
212    pub(crate) on_abort: Vec<Box<dyn Hook<OnAbort>>>,
213}
214
215impl Default for HookRegistry {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
221impl HookRegistry {
222    /// 空の HookRegistry を作成
223    pub fn new() -> Self {
224        Self {
225            on_prompt_submit: Vec::new(),
226            pre_llm_request: Vec::new(),
227            pre_tool_call: Vec::new(),
228            post_tool_call: Vec::new(),
229            on_turn_end: Vec::new(),
230            on_abort: Vec::new(),
231        }
232    }
233}