llm_worker/
worker.rs

1use std::collections::HashMap;
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex};
4
5use futures::StreamExt;
6use tokio::sync::mpsc;
7use tracing::{debug, info, trace, warn};
8
9use crate::{
10    ContentPart, Message, MessageContent, Role,
11    hook::{
12        Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd,
13        OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest,
14        PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult,
15    },
16    llm_client::{
17        ClientError, ConfigWarning, LlmClient, Request, RequestConfig,
18        ToolDefinition as LlmToolDefinition,
19    },
20    state::{CacheLocked, Mutable, WorkerState},
21    subscriber::{
22        ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
23        ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
24    },
25    timeline::{TextBlockCollector, Timeline, ToolCallCollector},
26    tool::{Tool, ToolDefinition, ToolError, ToolMeta},
27};
28
29// =============================================================================
30// Worker Error
31// =============================================================================
32
33/// Workerエラー
34#[derive(Debug, thiserror::Error)]
35pub enum WorkerError {
36    /// クライアントエラー
37    #[error("Client error: {0}")]
38    Client(#[from] ClientError),
39    /// ツールエラー
40    #[error("Tool error: {0}")]
41    Tool(#[from] ToolError),
42    /// Hookエラー
43    #[error("Hook error: {0}")]
44    Hook(#[from] HookError),
45    /// 処理が中断された
46    #[error("Aborted: {0}")]
47    Aborted(String),
48    /// Cancellation Tokenによって中断された
49    #[error("Cancelled")]
50    Cancelled,
51    /// 設定に関する警告(未サポートのオプション)
52    #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
53    ConfigWarnings(Vec<ConfigWarning>),
54}
55
56/// ツール登録エラー
57#[derive(Debug, thiserror::Error)]
58pub enum ToolRegistryError {
59    /// 同名のツールが既に登録されている
60    #[error("Tool with name '{0}' already registered")]
61    DuplicateName(String),
62}
63
64// =============================================================================
65// Worker Config
66// =============================================================================
67
68/// Worker設定
69#[derive(Debug, Clone, Default)]
70pub struct WorkerConfig {
71    // 将来の拡張用(現在は空)
72    _private: (),
73}
74
75// =============================================================================
76// Worker Result Types
77// =============================================================================
78
79/// Workerの実行結果(ステータス)
80#[derive(Debug)]
81pub enum WorkerResult {
82    /// 完了(ユーザー入力待ち状態)
83    Finished,
84    /// 一時停止(再開可能)
85    Paused,
86}
87
88/// 内部用: ツール実行結果
89enum ToolExecutionResult {
90    Completed(Vec<ToolResult>),
91    Paused,
92}
93
94// =============================================================================
95// ターン制御用コールバック保持
96// =============================================================================
97
98/// ターンイベントを通知するためのコールバック (型消去)
99trait TurnNotifier: Send + Sync {
100    fn on_turn_start(&self, turn: usize);
101    fn on_turn_end(&self, turn: usize);
102}
103
104struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
105    subscriber: Arc<Mutex<S>>,
106}
107
108impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
109    fn on_turn_start(&self, turn: usize) {
110        if let Ok(mut s) = self.subscriber.lock() {
111            s.on_turn_start(turn);
112        }
113    }
114
115    fn on_turn_end(&self, turn: usize) {
116        if let Ok(mut s) = self.subscriber.lock() {
117            s.on_turn_end(turn);
118        }
119    }
120}
121
122// =============================================================================
123// Worker
124// =============================================================================
125
126/// LLMとの対話を管理する中心コンポーネント
127///
128/// ユーザーからの入力を受け取り、LLMにリクエストを送信し、
129/// ツール呼び出しがあれば自動的に実行してターンを進行させます。
130///
131/// # 状態遷移(Type-state)
132///
133/// - [`Mutable`]: 初期状態。システムプロンプトや履歴を自由に編集可能。
134/// - [`CacheLocked`]: キャッシュ保護状態。`lock()`で遷移。前方コンテキストは不変。
135///
136/// # Examples
137///
138/// ```ignore
139/// use llm_worker::{Worker, Message};
140///
141/// // Workerを作成してツールを登録
142/// let mut worker = Worker::new(client)
143///     .system_prompt("You are a helpful assistant.");
144/// worker.register_tool(my_tool);
145///
146/// // 対話を実行
147/// let history = worker.run("Hello!").await?;
148/// ```
149///
150/// # キャッシュ保護が必要な場合
151///
152/// ```ignore
153/// let mut worker = Worker::new(client)
154///     .system_prompt("...");
155///
156/// // 履歴を設定後、ロックしてキャッシュを保護
157/// let mut locked = worker.lock();
158/// locked.run("user input").await?;
159/// ```
160pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
161    /// LLMクライアント
162    client: C,
163    /// イベントタイムライン
164    timeline: Timeline,
165    /// テキストブロックコレクター(Timeline用ハンドラ)
166    text_block_collector: TextBlockCollector,
167    /// ツールコールコレクター(Timeline用ハンドラ)
168    tool_call_collector: ToolCallCollector,
169    /// 登録されたツール (meta, instance)
170    tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
171    /// Hook レジストリ
172    hooks: HookRegistry,
173    /// システムプロンプト
174    system_prompt: Option<String>,
175    /// メッセージ履歴(Workerが所有)
176    history: Vec<Message>,
177    /// ロック時点での履歴長(CacheLocked状態でのみ意味を持つ)
178    locked_prefix_len: usize,
179    /// ターンカウント
180    turn_count: usize,
181    /// ターン通知用のコールバック
182    turn_notifiers: Vec<Box<dyn TurnNotifier>>,
183    /// リクエスト設定(max_tokens, temperature等)
184    request_config: RequestConfig,
185    /// 前回の実行が中断されたかどうか
186    last_run_interrupted: bool,
187    /// キャンセル通知用チャネル(実行中断用)
188    cancel_tx: mpsc::Sender<()>,
189    cancel_rx: mpsc::Receiver<()>,
190    /// 状態マーカー
191    _state: PhantomData<S>,
192}
193
194// =============================================================================
195// 共通実装(全状態で利用可能)
196// =============================================================================
197
198impl<C: LlmClient, S: WorkerState> Worker<C, S> {
199    fn reset_interruption_state(&mut self) {
200        self.last_run_interrupted = false;
201    }
202
203    /// ターンを実行
204    ///
205    /// 新しいユーザーメッセージを履歴に追加し、LLMにリクエストを送信する。
206    /// ツール呼び出しがある場合は自動的にループする。
207    pub async fn run(
208        &mut self,
209        user_input: impl Into<String>,
210    ) -> Result<WorkerResult, WorkerError> {
211        self.reset_interruption_state();
212        // Hook: on_prompt_submit
213        let mut user_message = Message::user(user_input);
214        let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
215        let result = match result {
216            Ok(value) => value,
217            Err(err) => return self.finalize_interruption(Err(err)).await,
218        };
219        match result {
220            OnPromptSubmitResult::Cancel(reason) => {
221                self.last_run_interrupted = true;
222                return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
223            }
224            OnPromptSubmitResult::Continue => {}
225        }
226        self.history.push(user_message);
227        let result = self.run_turn_loop().await;
228        self.finalize_interruption(result).await
229    }
230
231    fn drain_cancel_queue(&mut self) {
232        use tokio::sync::mpsc::error::TryRecvError;
233        loop {
234            match self.cancel_rx.try_recv() {
235                Ok(()) => continue,
236                Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
237            }
238        }
239    }
240
241    fn try_cancelled(&mut self) -> bool {
242        use tokio::sync::mpsc::error::TryRecvError;
243        match self.cancel_rx.try_recv() {
244            Ok(()) => true,
245            Err(TryRecvError::Empty) => false,
246            Err(TryRecvError::Disconnected) => true,
247        }
248    }
249
250    /// イベント購読者を登録する
251    ///
252    /// 登録したSubscriberは、LLMからのストリーミングイベントを
253    /// リアルタイムで受信できます。UIへのストリーム表示などに利用します。
254    ///
255    /// # 受信できるイベント
256    ///
257    /// - **ブロックイベント**: `on_text_block`, `on_tool_use_block`
258    /// - **メタイベント**: `on_usage`, `on_status`, `on_error`
259    /// - **完了イベント**: `on_text_complete`, `on_tool_call_complete`
260    /// - **ターン制御**: `on_turn_start`, `on_turn_end`
261    ///
262    /// # Examples
263    ///
264    /// ```ignore
265    /// use llm_worker::{Worker, WorkerSubscriber, TextBlockEvent};
266    ///
267    /// struct MyPrinter;
268    /// impl WorkerSubscriber for MyPrinter {
269    ///     type TextBlockScope = ();
270    ///     type ToolUseBlockScope = ();
271    ///
272    ///     fn on_text_block(&mut self, _: &mut (), event: &TextBlockEvent) {
273    ///         if let TextBlockEvent::Delta(text) = event {
274    ///             print!("{}", text);
275    ///         }
276    ///     }
277    /// }
278    ///
279    /// worker.subscribe(MyPrinter);
280    /// ```
281    pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
282        let subscriber = Arc::new(Mutex::new(subscriber));
283
284        // TextBlock用ハンドラを登録
285        self.timeline
286            .on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone()));
287
288        // ToolUseBlock用ハンドラを登録
289        self.timeline
290            .on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone()));
291
292        // Meta系ハンドラを登録
293        self.timeline
294            .on_usage(UsageSubscriberAdapter::new(subscriber.clone()));
295        self.timeline
296            .on_status(StatusSubscriberAdapter::new(subscriber.clone()));
297        self.timeline
298            .on_error(ErrorSubscriberAdapter::new(subscriber.clone()));
299
300        // ターン制御用コールバックを登録
301        self.turn_notifiers
302            .push(Box::new(SubscriberTurnNotifier { subscriber }));
303    }
304
305    /// ツールを登録する
306    ///
307    /// 登録されたツールはLLMからの呼び出しで自動的に実行されます。
308    /// 同名のツールを登録するとエラーになります。
309    ///
310    /// # Examples
311    ///
312    /// ```ignore
313    /// use llm_worker::tool::{ToolMeta, ToolDefinition, Tool};
314    /// use std::sync::Arc;
315    ///
316    /// let def: ToolDefinition = Arc::new(|| {
317    ///     (ToolMeta::new("search").description("..."), Arc::new(MyTool) as Arc<dyn Tool>)
318    /// });
319    /// worker.register_tool(def)?;
320    /// ```
321    pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> {
322        let (meta, instance) = factory();
323        if self.tools.contains_key(&meta.name) {
324            return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
325        }
326        self.tools.insert(meta.name.clone(), (meta, instance));
327        Ok(())
328    }
329
330    /// 複数のツールを登録
331    pub fn register_tools(
332        &mut self,
333        factories: impl IntoIterator<Item = ToolDefinition>,
334    ) -> Result<(), ToolRegistryError> {
335        for factory in factories {
336            self.register_tool(factory)?;
337        }
338        Ok(())
339    }
340
341    /// on_prompt_submit Hookを追加する
342    ///
343    /// `run()` でユーザーメッセージを受け取った直後に呼び出される。
344    pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
345        self.hooks.on_prompt_submit.push(Box::new(hook));
346    }
347
348    /// pre_llm_request Hookを追加する
349    ///
350    /// 各ターンのLLMリクエスト送信前に呼び出される。
351    pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
352        self.hooks.pre_llm_request.push(Box::new(hook));
353    }
354
355    /// pre_tool_call Hookを追加する
356    pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
357        self.hooks.pre_tool_call.push(Box::new(hook));
358    }
359
360    /// post_tool_call Hookを追加する
361    pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
362        self.hooks.post_tool_call.push(Box::new(hook));
363    }
364
365    /// on_turn_end Hookを追加する
366    pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
367        self.hooks.on_turn_end.push(Box::new(hook));
368    }
369
370    /// on_abort Hookを追加する
371    pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
372        self.hooks.on_abort.push(Box::new(hook));
373    }
374
375    /// タイムラインへの可変参照を取得(追加ハンドラ登録用)
376    pub fn timeline_mut(&mut self) -> &mut Timeline {
377        &mut self.timeline
378    }
379
380    /// 履歴への参照を取得
381    pub fn history(&self) -> &[Message] {
382        &self.history
383    }
384
385    /// システムプロンプトへの参照を取得
386    pub fn get_system_prompt(&self) -> Option<&str> {
387        self.system_prompt.as_deref()
388    }
389
390    /// 現在のターンカウントを取得
391    pub fn turn_count(&self) -> usize {
392        self.turn_count
393    }
394
395    /// 現在のリクエスト設定への参照を取得
396    pub fn request_config(&self) -> &RequestConfig {
397        &self.request_config
398    }
399
400    /// 最大トークン数を設定
401    ///
402    /// この設定はキャッシュロックとは独立しており、各リクエストに適用されます。
403    ///
404    /// # Examples
405    ///
406    /// ```ignore
407    /// worker.set_max_tokens(4096);
408    /// ```
409    pub fn set_max_tokens(&mut self, max_tokens: u32) {
410        self.request_config.max_tokens = Some(max_tokens);
411    }
412
413    /// temperatureを設定
414    ///
415    /// 0.0から1.0(または2.0)の範囲で設定します。
416    /// 低い値はより決定的な出力を、高い値はより多様な出力を生成します。
417    ///
418    /// # Examples
419    ///
420    /// ```ignore
421    /// worker.set_temperature(0.7);
422    /// ```
423    pub fn set_temperature(&mut self, temperature: f32) {
424        self.request_config.temperature = Some(temperature);
425    }
426
427    /// top_pを設定(nucleus sampling)
428    ///
429    /// # Examples
430    ///
431    /// ```ignore
432    /// worker.set_top_p(0.9);
433    /// ```
434    pub fn set_top_p(&mut self, top_p: f32) {
435        self.request_config.top_p = Some(top_p);
436    }
437
438    /// top_kを設定
439    ///
440    /// トークン選択時に考慮する上位k個のトークンを指定します。
441    ///
442    /// # Examples
443    ///
444    /// ```ignore
445    /// worker.set_top_k(40);
446    /// ```
447    pub fn set_top_k(&mut self, top_k: u32) {
448        self.request_config.top_k = Some(top_k);
449    }
450
451    /// ストップシーケンスを追加
452    ///
453    /// # Examples
454    ///
455    /// ```ignore
456    /// worker.add_stop_sequence("\n\n");
457    /// ```
458    pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
459        self.request_config.stop_sequences.push(sequence.into());
460    }
461
462    /// ストップシーケンスをクリア
463    pub fn clear_stop_sequences(&mut self) {
464        self.request_config.stop_sequences.clear();
465    }
466
467    /// キャンセル通知用Senderを取得する
468    pub fn cancel_sender(&self) -> mpsc::Sender<()> {
469        self.cancel_tx.clone()
470    }
471
472    /// リクエスト設定を一括で設定
473    pub fn set_request_config(&mut self, config: RequestConfig) {
474        self.request_config = config;
475    }
476
477    /// 実行をキャンセルする
478    ///
479    /// 現在実行中のストリーミングやツール実行を中断します。
480    /// 次のイベントループのチェックポイントでWorkerError::Cancelledが返されます。
481    ///
482    /// # Examples
483    ///
484    /// ```ignore
485    /// use std::sync::Arc;
486    /// let worker = Arc::new(Mutex::new(Worker::new(client)));
487    ///
488    /// // 別スレッドで実行
489    /// let worker_clone = worker.clone();
490    /// tokio::spawn(async move {
491    ///     let mut w = worker_clone.lock().unwrap();
492    ///     w.run("Long task...").await
493    /// });
494    ///
495    /// // キャンセル
496    /// worker.lock().unwrap().cancel();
497    /// ```
498    pub fn cancel(&self) {
499        let _ = self.cancel_tx.try_send(());
500    }
501
502    /// キャンセルされているかチェック
503    pub fn is_cancelled(&mut self) -> bool {
504        self.try_cancelled()
505    }
506
507    /// 前回の実行が中断されたかどうか
508    pub fn last_run_interrupted(&self) -> bool {
509        self.last_run_interrupted
510    }
511
512    /// 登録されたツールからLLM用ToolDefinitionのリストを生成
513    fn build_tool_definitions(&self) -> Vec<LlmToolDefinition> {
514        self.tools
515            .values()
516            .map(|(meta, _)| {
517                LlmToolDefinition::new(&meta.name)
518                    .description(&meta.description)
519                    .input_schema(meta.input_schema.clone())
520            })
521            .collect()
522    }
523
524    /// テキストブロックとツール呼び出しからアシスタントメッセージを構築
525    fn build_assistant_message(
526        &self,
527        text_blocks: &[String],
528        tool_calls: &[ToolCall],
529    ) -> Option<Message> {
530        // テキストもツール呼び出しもない場合はNone
531        if text_blocks.is_empty() && tool_calls.is_empty() {
532            return None;
533        }
534
535        // テキストのみの場合はシンプルなテキストメッセージ
536        if tool_calls.is_empty() {
537            let text = text_blocks.join("");
538            return Some(Message::assistant(text));
539        }
540
541        // ツール呼び出しがある場合は Parts として構築
542        let mut parts = Vec::new();
543
544        // テキストパーツを追加
545        for text in text_blocks {
546            if !text.is_empty() {
547                parts.push(ContentPart::Text { text: text.clone() });
548            }
549        }
550
551        // ツール呼び出しパーツを追加
552        for call in tool_calls {
553            parts.push(ContentPart::ToolUse {
554                id: call.id.clone(),
555                name: call.name.clone(),
556                input: call.input.clone(),
557            });
558        }
559
560        Some(Message {
561            role: Role::Assistant,
562            content: MessageContent::Parts(parts),
563        })
564    }
565
566    /// リクエストを構築
567    fn build_request(
568        &self,
569        tool_definitions: &[LlmToolDefinition],
570        context: &[Message],
571    ) -> Request {
572        let mut request = Request::new();
573
574        // システムプロンプトを設定
575        if let Some(ref system) = self.system_prompt {
576            request = request.system(system);
577        }
578
579        // メッセージを追加
580        for msg in context {
581            // Message から llm_client::Message への変換
582            request = request.message(crate::llm_client::Message {
583                role: match msg.role {
584                    Role::User => crate::llm_client::Role::User,
585                    Role::Assistant => crate::llm_client::Role::Assistant,
586                },
587                content: match &msg.content {
588                    MessageContent::Text(t) => crate::llm_client::MessageContent::Text(t.clone()),
589                    MessageContent::ToolResult {
590                        tool_use_id,
591                        content,
592                    } => crate::llm_client::MessageContent::ToolResult {
593                        tool_use_id: tool_use_id.clone(),
594                        content: content.clone(),
595                    },
596                    MessageContent::Parts(parts) => crate::llm_client::MessageContent::Parts(
597                        parts
598                            .iter()
599                            .map(|p| match p {
600                                ContentPart::Text { text } => {
601                                    crate::llm_client::ContentPart::Text { text: text.clone() }
602                                }
603                                ContentPart::ToolUse { id, name, input } => {
604                                    crate::llm_client::ContentPart::ToolUse {
605                                        id: id.clone(),
606                                        name: name.clone(),
607                                        input: input.clone(),
608                                    }
609                                }
610                                ContentPart::ToolResult {
611                                    tool_use_id,
612                                    content,
613                                } => crate::llm_client::ContentPart::ToolResult {
614                                    tool_use_id: tool_use_id.clone(),
615                                    content: content.clone(),
616                                },
617                            })
618                            .collect(),
619                    ),
620                },
621            });
622        }
623
624        // ツール定義を追加
625        for tool_def in tool_definitions {
626            request = request.tool(tool_def.clone());
627        }
628
629        // リクエスト設定を適用
630        request = request.config(self.request_config.clone());
631
632        request
633    }
634
635    /// Hooks: on_prompt_submit
636    ///
637    /// `run()` でユーザーメッセージを受け取った直後に呼び出される(最初だけ)。
638    async fn run_on_prompt_submit_hooks(
639        &self,
640        message: &mut Message,
641    ) -> Result<OnPromptSubmitResult, WorkerError> {
642        for hook in &self.hooks.on_prompt_submit {
643            let result = hook.call(message).await?;
644            match result {
645                OnPromptSubmitResult::Continue => continue,
646                OnPromptSubmitResult::Cancel(reason) => {
647                    return Ok(OnPromptSubmitResult::Cancel(reason));
648                }
649            }
650        }
651        Ok(OnPromptSubmitResult::Continue)
652    }
653
654    /// Hooks: pre_llm_request
655    ///
656    /// 各ターンのLLMリクエスト送信前に呼び出される(毎ターン)。
657    async fn run_pre_llm_request_hooks(
658        &self,
659    ) -> Result<(PreLlmRequestResult, Vec<Message>), WorkerError> {
660        let mut temp_context = self.history.clone();
661        for hook in &self.hooks.pre_llm_request {
662            let result = hook.call(&mut temp_context).await?;
663            match result {
664                PreLlmRequestResult::Continue => continue,
665                PreLlmRequestResult::Cancel(reason) => {
666                    return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
667                }
668            }
669        }
670        Ok((PreLlmRequestResult::Continue, temp_context))
671    }
672
673    /// Hooks: on_turn_end
674    async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
675        let mut temp_messages = self.history.clone();
676        for hook in &self.hooks.on_turn_end {
677            let result = hook.call(&mut temp_messages).await?;
678            match result {
679                OnTurnEndResult::Finish => continue,
680                OnTurnEndResult::ContinueWithMessages(msgs) => {
681                    return Ok(OnTurnEndResult::ContinueWithMessages(msgs));
682                }
683                OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
684            }
685        }
686        Ok(OnTurnEndResult::Finish)
687    }
688
689    /// Hooks: on_abort
690    async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
691        let mut reason = reason.to_string();
692        for hook in &self.hooks.on_abort {
693            hook.call(&mut reason).await?;
694        }
695        Ok(())
696    }
697
698    async fn finalize_interruption<T>(
699        &mut self,
700        result: Result<T, WorkerError>,
701    ) -> Result<T, WorkerError> {
702        match result {
703            Ok(value) => Ok(value),
704            Err(err) => {
705                self.last_run_interrupted = true;
706                let reason = match &err {
707                    WorkerError::Aborted(reason) => reason.clone(),
708                    WorkerError::Cancelled => "Cancelled".to_string(),
709                    _ => err.to_string(),
710                };
711                if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
712                    self.last_run_interrupted = true;
713                    return Err(hook_err);
714                }
715                Err(err)
716            }
717        }
718    }
719
720    /// 未実行のツール呼び出しがあるかチェック(Pauseからの復帰用)
721    fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
722        let last_msg = self.history.last()?;
723        if last_msg.role != Role::Assistant {
724            return None;
725        }
726
727        let mut calls = Vec::new();
728        if let MessageContent::Parts(parts) = &last_msg.content {
729            for part in parts {
730                if let ContentPart::ToolUse { id, name, input } = part {
731                    calls.push(ToolCall {
732                        id: id.clone(),
733                        name: name.clone(),
734                        input: input.clone(),
735                    });
736                }
737            }
738        }
739
740        if calls.is_empty() { None } else { Some(calls) }
741    }
742
743    /// ツールを並列実行
744    ///
745    /// 全てのツールに対してpre_tool_callフックを実行後、
746    /// 許可されたツールを並列に実行し、結果にpost_tool_callフックを適用する。
747    async fn execute_tools(
748        &mut self,
749        tool_calls: Vec<ToolCall>,
750    ) -> Result<ToolExecutionResult, WorkerError> {
751        use futures::future::join_all;
752
753        // ツール呼び出しIDから (ToolCall, Meta, Tool) へのマップ
754        // PostToolCallフックで必要になるため保持する
755        let mut call_info_map = HashMap::new();
756
757        // Phase 1: pre_tool_call フックを適用(スキップ/中断を判定)
758        let mut approved_calls = Vec::new();
759        for mut tool_call in tool_calls {
760            // ツール定義を取得
761            if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
762                // コンテキストを作成
763                let mut context = ToolCallContext {
764                    call: tool_call.clone(),
765                    meta: meta.clone(),
766                    tool: tool.clone(),
767                };
768
769                let mut skip = false;
770                for hook in &self.hooks.pre_tool_call {
771                    let result = hook
772                        .call(&mut context)
773                        .await
774                        .inspect_err(|_| self.last_run_interrupted = true)?;
775                    match result {
776                        PreToolCallResult::Continue => {}
777                        PreToolCallResult::Skip => {
778                            skip = true;
779                            break;
780                        }
781                        PreToolCallResult::Abort(reason) => {
782                            self.last_run_interrupted = true;
783                            return Err(WorkerError::Aborted(reason));
784                        }
785                        PreToolCallResult::Pause => {
786                            self.last_run_interrupted = true;
787                            return Ok(ToolExecutionResult::Paused);
788                        }
789                    }
790                }
791
792                // フックで変更された内容を反映
793                tool_call = context.call;
794
795                // マップに保存(実行する場合のみ)
796                if !skip {
797                    call_info_map.insert(
798                        tool_call.id.clone(),
799                        (tool_call.clone(), meta.clone(), tool.clone()),
800                    );
801                    approved_calls.push(tool_call);
802                }
803            } else {
804                // 未知のツールはそのまま承認リストに入れる(実行時にエラーになる)
805                // Hookは適用しない(Metaがないため)
806                approved_calls.push(tool_call);
807            }
808        }
809
810        // Phase 2: 許可されたツールを並列実行(キャンセル可能)
811        let futures: Vec<_> = approved_calls
812            .into_iter()
813            .map(|tool_call| {
814                let tools = &self.tools;
815                async move {
816                    if let Some((_, tool)) = tools.get(&tool_call.name) {
817                        let input_json =
818                            serde_json::to_string(&tool_call.input).unwrap_or_default();
819                        match tool.execute(&input_json).await {
820                            Ok(content) => ToolResult::success(&tool_call.id, content),
821                            Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
822                        }
823                    } else {
824                        ToolResult::error(
825                            &tool_call.id,
826                            format!("Tool '{}' not found", tool_call.name),
827                        )
828                    }
829                }
830            })
831            .collect();
832
833        // ツール実行をキャンセル可能にする
834        let mut results = tokio::select! {
835            results = join_all(futures) => results,
836            cancel = self.cancel_rx.recv() => {
837                if cancel.is_some() {
838                    info!("Tool execution cancelled");
839                }
840                self.timeline.abort_current_block();
841                self.last_run_interrupted = true;
842                return Err(WorkerError::Cancelled);
843            }
844        };
845
846        // Phase 3: post_tool_call フックを適用
847        for tool_result in &mut results {
848            // 保存しておいた情報を取得
849            if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
850                let mut context = PostToolCallContext {
851                    call: tool_call.clone(),
852                    result: tool_result.clone(),
853                    meta: meta.clone(),
854                    tool: tool.clone(),
855                };
856
857                for hook in &self.hooks.post_tool_call {
858                    let result = hook
859                        .call(&mut context)
860                        .await
861                        .inspect_err(|_| self.last_run_interrupted = true)?;
862                    match result {
863                        PostToolCallResult::Continue => {}
864                        PostToolCallResult::Abort(reason) => {
865                            self.last_run_interrupted = true;
866                            return Err(WorkerError::Aborted(reason));
867                        }
868                    }
869                }
870                // フックで変更された結果を反映
871                *tool_result = context.result;
872            }
873        }
874
875        Ok(ToolExecutionResult::Completed(results))
876    }
877
878    /// 内部で使用するターン実行ロジック
879    async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
880        self.reset_interruption_state();
881        self.drain_cancel_queue();
882        let tool_definitions = self.build_tool_definitions();
883
884        info!(
885            message_count = self.history.len(),
886            tool_count = tool_definitions.len(),
887            "Starting worker run"
888        );
889
890        // Resume check: Pending tool calls
891        if let Some(tool_calls) = self.get_pending_tool_calls() {
892            info!("Resuming pending tool calls");
893            match self.execute_tools(tool_calls).await {
894                Ok(ToolExecutionResult::Paused) => {
895                    self.last_run_interrupted = true;
896                    return Ok(WorkerResult::Paused);
897                }
898                Ok(ToolExecutionResult::Completed(results)) => {
899                    for result in results {
900                        self.history
901                            .push(Message::tool_result(&result.tool_use_id, &result.content));
902                    }
903                    // Continue to loop
904                }
905                Err(err) => {
906                    self.last_run_interrupted = true;
907                    return Err(err);
908                }
909            }
910        }
911
912        loop {
913            // キャンセルチェック
914            if self.try_cancelled() {
915                info!("Execution cancelled");
916                self.timeline.abort_current_block();
917                self.last_run_interrupted = true;
918                return Err(WorkerError::Cancelled);
919            }
920
921            // ターン開始を通知
922            let current_turn = self.turn_count;
923            debug!(turn = current_turn, "Turn start");
924            for notifier in &self.turn_notifiers {
925                notifier.on_turn_start(current_turn);
926            }
927
928            // Hook: pre_llm_request
929            let (control, request_context) = self
930                .run_pre_llm_request_hooks()
931                .await
932                .inspect_err(|_| self.last_run_interrupted = true)?;
933            match control {
934                PreLlmRequestResult::Cancel(reason) => {
935                    info!(reason = %reason, "Aborted by hook");
936                    for notifier in &self.turn_notifiers {
937                        notifier.on_turn_end(current_turn);
938                    }
939                    self.last_run_interrupted = true;
940                    return Err(WorkerError::Aborted(reason));
941                }
942                PreLlmRequestResult::Continue => {}
943            }
944
945            // リクエスト構築
946            let request = self.build_request(&tool_definitions, &request_context);
947            debug!(
948                message_count = request.messages.len(),
949                tool_count = request.tools.len(),
950                has_system = request.system_prompt.is_some(),
951                "Sending request to LLM"
952            );
953
954            // ストリーム処理
955            debug!("Starting stream...");
956            let mut event_count = 0;
957
958            // ストリームを取得(キャンセル可能)
959            let mut stream = tokio::select! {
960                stream_result = self.client.stream(request) => stream_result
961                    .inspect_err(|_| self.last_run_interrupted = true)?,
962                cancel = self.cancel_rx.recv() => {
963                    if cancel.is_some() {
964                        info!("Cancelled before stream started");
965                    }
966                    self.timeline.abort_current_block();
967                    self.last_run_interrupted = true;
968                    return Err(WorkerError::Cancelled);
969                }
970            };
971
972            loop {
973                tokio::select! {
974                    // ストリームからイベントを受信
975                    event_result = stream.next() => {
976                        match event_result {
977                            Some(result) => {
978                                match &result {
979                                    Ok(event) => {
980                                        trace!(event = ?event, "Received event");
981                                        event_count += 1;
982                                    }
983                                    Err(e) => {
984                                        warn!(error = %e, "Stream error");
985                                    }
986                                }
987                                let event = result
988                                    .inspect_err(|_| self.last_run_interrupted = true)?;
989                                let timeline_event: crate::timeline::event::Event = event.into();
990                                self.timeline.dispatch(&timeline_event);
991                            }
992                            None => break, // ストリーム終了
993                        }
994                    }
995                    // キャンセル待機
996                    cancel = self.cancel_rx.recv() => {
997                        if cancel.is_some() {
998                            info!("Stream cancelled");
999                        }
1000                        self.timeline.abort_current_block();
1001                        self.last_run_interrupted = true;
1002                        return Err(WorkerError::Cancelled);
1003                    }
1004                }
1005            }
1006            debug!(event_count = event_count, "Stream completed");
1007
1008            // ターン終了を通知
1009            for notifier in &self.turn_notifiers {
1010                notifier.on_turn_end(current_turn);
1011            }
1012            self.turn_count += 1;
1013
1014            // 収集結果を取得
1015            let text_blocks = self.text_block_collector.take_collected();
1016            let tool_calls = self.tool_call_collector.take_collected();
1017
1018            // アシスタントメッセージを履歴に追加
1019            let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
1020            if let Some(msg) = assistant_message {
1021                self.history.push(msg);
1022            }
1023
1024            if tool_calls.is_empty() {
1025                // ツール呼び出しなし → ターン終了判定
1026                let turn_result = self
1027                    .run_on_turn_end_hooks()
1028                    .await
1029                    .inspect_err(|_| self.last_run_interrupted = true)?;
1030                match turn_result {
1031                    OnTurnEndResult::Finish => {
1032                        self.last_run_interrupted = false;
1033                        return Ok(WorkerResult::Finished);
1034                    }
1035                    OnTurnEndResult::ContinueWithMessages(additional) => {
1036                        self.history.extend(additional);
1037                        continue;
1038                    }
1039                    OnTurnEndResult::Paused => {
1040                        self.last_run_interrupted = true;
1041                        return Ok(WorkerResult::Paused);
1042                    }
1043                }
1044            }
1045
1046            // ツール実行
1047            match self.execute_tools(tool_calls).await {
1048                Ok(ToolExecutionResult::Paused) => {
1049                    self.last_run_interrupted = true;
1050                    return Ok(WorkerResult::Paused);
1051                }
1052                Ok(ToolExecutionResult::Completed(results)) => {
1053                    for result in results {
1054                        self.history
1055                            .push(Message::tool_result(&result.tool_use_id, &result.content));
1056                    }
1057                }
1058                Err(err) => {
1059                    self.last_run_interrupted = true;
1060                    return Err(err);
1061                }
1062            }
1063        }
1064    }
1065
1066    /// 実行を再開(Pause状態からの復帰)
1067    ///
1068    /// 新しいユーザーメッセージを履歴に追加せず、現在の状態からターン処理を再開する。
1069    pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
1070        self.reset_interruption_state();
1071        let result = self.run_turn_loop().await;
1072        self.finalize_interruption(result).await
1073    }
1074}
1075
1076// =============================================================================
1077// Mutable状態専用の実装
1078// =============================================================================
1079
1080impl<C: LlmClient> Worker<C, Mutable> {
1081    /// 新しいWorkerを作成(Mutable状態)
1082    pub fn new(client: C) -> Self {
1083        let text_block_collector = TextBlockCollector::new();
1084        let tool_call_collector = ToolCallCollector::new();
1085        let mut timeline = Timeline::new();
1086        let (cancel_tx, cancel_rx) = mpsc::channel(1);
1087
1088        // コレクターをTimelineに登録
1089        timeline.on_text_block(text_block_collector.clone());
1090        timeline.on_tool_use_block(tool_call_collector.clone());
1091
1092        Self {
1093            client,
1094            timeline,
1095            text_block_collector,
1096            tool_call_collector,
1097            tools: HashMap::new(),
1098            hooks: HookRegistry::new(),
1099            system_prompt: None,
1100            history: Vec::new(),
1101            locked_prefix_len: 0,
1102            turn_count: 0,
1103            turn_notifiers: Vec::new(),
1104            request_config: RequestConfig::default(),
1105            last_run_interrupted: false,
1106            cancel_tx,
1107            cancel_rx,
1108            _state: PhantomData,
1109        }
1110    }
1111
1112    /// システムプロンプトを設定(ビルダーパターン)
1113    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
1114        self.system_prompt = Some(prompt.into());
1115        self
1116    }
1117
1118    /// システムプロンプトを設定(可変参照版)
1119    pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
1120        self.system_prompt = Some(prompt.into());
1121    }
1122
1123    /// 最大トークン数を設定(ビルダーパターン)
1124    ///
1125    /// # Examples
1126    ///
1127    /// ```ignore
1128    /// let worker = Worker::new(client)
1129    ///     .system_prompt("You are a helpful assistant.")
1130    ///     .max_tokens(4096);
1131    /// ```
1132    pub fn max_tokens(mut self, max_tokens: u32) -> Self {
1133        self.request_config.max_tokens = Some(max_tokens);
1134        self
1135    }
1136
1137    /// temperatureを設定(ビルダーパターン)
1138    ///
1139    /// # Examples
1140    ///
1141    /// ```ignore
1142    /// let worker = Worker::new(client)
1143    ///     .temperature(0.7);
1144    /// ```
1145    pub fn temperature(mut self, temperature: f32) -> Self {
1146        self.request_config.temperature = Some(temperature);
1147        self
1148    }
1149
1150    /// top_pを設定(ビルダーパターン)
1151    pub fn top_p(mut self, top_p: f32) -> Self {
1152        self.request_config.top_p = Some(top_p);
1153        self
1154    }
1155
1156    /// top_kを設定(ビルダーパターン)
1157    pub fn top_k(mut self, top_k: u32) -> Self {
1158        self.request_config.top_k = Some(top_k);
1159        self
1160    }
1161
1162    /// ストップシーケンスを追加(ビルダーパターン)
1163    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
1164        self.request_config.stop_sequences.push(sequence.into());
1165        self
1166    }
1167
1168    /// リクエスト設定をまとめて設定(ビルダーパターン)
1169    ///
1170    /// # Examples
1171    ///
1172    /// ```ignore
1173    /// let config = RequestConfig::new()
1174    ///     .with_max_tokens(4096)
1175    ///     .with_temperature(0.7);
1176    ///
1177    /// let worker = Worker::new(client)
1178    ///     .system_prompt("...")
1179    ///     .with_config(config);
1180    /// ```
1181    pub fn with_config(mut self, config: RequestConfig) -> Self {
1182        self.request_config = config;
1183        self
1184    }
1185
1186    /// 現在の設定をプロバイダに対してバリデーションする
1187    ///
1188    /// 未サポートの設定があればエラーを返す。
1189    /// チェーンの最後で呼び出すことで、設定の問題を早期に検出できる。
1190    ///
1191    /// # Examples
1192    ///
1193    /// ```ignore
1194    /// let worker = Worker::new(client)
1195    ///     .temperature(0.7)
1196    ///     .top_k(40)
1197    ///     .validate()?;  // OpenAIならtop_kがサポートされないためエラー
1198    /// ```
1199    ///
1200    /// # Returns
1201    /// * `Ok(Self)` - バリデーション成功
1202    /// * `Err(WorkerError::ConfigWarnings)` - 未サポートの設定がある
1203    pub fn validate(self) -> Result<Self, WorkerError> {
1204        let warnings = self.client.validate_config(&self.request_config);
1205        if warnings.is_empty() {
1206            Ok(self)
1207        } else {
1208            Err(WorkerError::ConfigWarnings(warnings))
1209        }
1210    }
1211
1212    /// 履歴への可変参照を取得
1213    ///
1214    /// Mutable状態でのみ利用可能。履歴を自由に編集できる。
1215    pub fn history_mut(&mut self) -> &mut Vec<Message> {
1216        &mut self.history
1217    }
1218
1219    /// 履歴を設定
1220    pub fn set_history(&mut self, messages: Vec<Message>) {
1221        self.history = messages;
1222    }
1223
1224    /// 履歴にメッセージを追加(ビルダーパターン)
1225    pub fn with_message(mut self, message: Message) -> Self {
1226        self.history.push(message);
1227        self
1228    }
1229
1230    /// 履歴にメッセージを追加
1231    pub fn push_message(&mut self, message: Message) {
1232        self.history.push(message);
1233    }
1234
1235    /// 複数のメッセージを履歴に追加(ビルダーパターン)
1236    pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
1237        self.history.extend(messages);
1238        self
1239    }
1240
1241    /// 複数のメッセージを履歴に追加
1242    pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
1243        self.history.extend(messages);
1244    }
1245
1246    /// 履歴をクリア
1247    pub fn clear_history(&mut self) {
1248        self.history.clear();
1249    }
1250
1251    /// 設定を適用(将来の拡張用)
1252    #[allow(dead_code)]
1253    pub fn config(self, _config: WorkerConfig) -> Self {
1254        self
1255    }
1256
1257    /// ロックしてCacheLocked状態へ遷移
1258    ///
1259    /// この操作により、現在のシステムプロンプトと履歴が「確定済みプレフィックス」として
1260    /// 固定される。以降は履歴への追記のみが可能となり、キャッシュヒットが保証される。
1261    pub fn lock(self) -> Worker<C, CacheLocked> {
1262        let locked_prefix_len = self.history.len();
1263        Worker {
1264            client: self.client,
1265            timeline: self.timeline,
1266            text_block_collector: self.text_block_collector,
1267            tool_call_collector: self.tool_call_collector,
1268            tools: self.tools,
1269            hooks: self.hooks,
1270            system_prompt: self.system_prompt,
1271            history: self.history,
1272            locked_prefix_len,
1273            turn_count: self.turn_count,
1274            turn_notifiers: self.turn_notifiers,
1275            request_config: self.request_config,
1276            last_run_interrupted: self.last_run_interrupted,
1277            cancel_tx: self.cancel_tx,
1278            cancel_rx: self.cancel_rx,
1279            _state: PhantomData,
1280        }
1281    }
1282
1283}
1284
1285// =============================================================================
1286// CacheLocked状態専用の実装
1287// =============================================================================
1288
1289impl<C: LlmClient> Worker<C, CacheLocked> {
1290    /// ロック時点のプレフィックス長を取得
1291    pub fn locked_prefix_len(&self) -> usize {
1292        self.locked_prefix_len
1293    }
1294
1295    /// ロックを解除してMutable状態へ戻す
1296    ///
1297    /// 注意: この操作を行うと、以降のリクエストでキャッシュがヒットしなくなる可能性がある。
1298    /// 履歴を編集する必要がある場合にのみ使用すること。
1299    pub fn unlock(self) -> Worker<C, Mutable> {
1300        Worker {
1301            client: self.client,
1302            timeline: self.timeline,
1303            text_block_collector: self.text_block_collector,
1304            tool_call_collector: self.tool_call_collector,
1305            tools: self.tools,
1306            hooks: self.hooks,
1307            system_prompt: self.system_prompt,
1308            history: self.history,
1309            locked_prefix_len: 0,
1310            turn_count: self.turn_count,
1311            turn_notifiers: self.turn_notifiers,
1312            request_config: self.request_config,
1313            last_run_interrupted: self.last_run_interrupted,
1314            cancel_tx: self.cancel_tx,
1315            cancel_rx: self.cancel_rx,
1316            _state: PhantomData,
1317        }
1318    }
1319}
1320
1321#[cfg(test)]
1322mod tests {
1323    // 基本的なテストのみ。LlmClientを使ったテストは統合テストで行う。
1324}