1use std::collections::HashMap;
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex};
4
5use futures::StreamExt;
6use tokio::sync::mpsc;
7use tracing::{debug, info, trace, warn};
8
9use crate::{
10 ContentPart, Message, MessageContent, Role,
11 hook::{
12 Hook, HookError, HookRegistry, OnAbort, OnPromptSubmit, OnPromptSubmitResult, OnTurnEnd,
13 OnTurnEndResult, PostToolCall, PostToolCallContext, PostToolCallResult, PreLlmRequest,
14 PreLlmRequestResult, PreToolCall, PreToolCallResult, ToolCall, ToolCallContext, ToolResult,
15 },
16 llm_client::{
17 ClientError, ConfigWarning, LlmClient, Request, RequestConfig,
18 ToolDefinition as LlmToolDefinition,
19 },
20 state::{CacheLocked, Mutable, WorkerState},
21 subscriber::{
22 ErrorSubscriberAdapter, StatusSubscriberAdapter, TextBlockSubscriberAdapter,
23 ToolUseBlockSubscriberAdapter, UsageSubscriberAdapter, WorkerSubscriber,
24 },
25 timeline::{TextBlockCollector, Timeline, ToolCallCollector},
26 tool::{Tool, ToolDefinition, ToolError, ToolMeta},
27};
28
29#[derive(Debug, thiserror::Error)]
35pub enum WorkerError {
36 #[error("Client error: {0}")]
38 Client(#[from] ClientError),
39 #[error("Tool error: {0}")]
41 Tool(#[from] ToolError),
42 #[error("Hook error: {0}")]
44 Hook(#[from] HookError),
45 #[error("Aborted: {0}")]
47 Aborted(String),
48 #[error("Cancelled")]
50 Cancelled,
51 #[error("Config warnings: {}", .0.iter().map(|w| w.to_string()).collect::<Vec<_>>().join(", "))]
53 ConfigWarnings(Vec<ConfigWarning>),
54}
55
56#[derive(Debug, thiserror::Error)]
58pub enum ToolRegistryError {
59 #[error("Tool with name '{0}' already registered")]
61 DuplicateName(String),
62}
63
/// Reserved worker-level configuration.
///
/// It currently carries no options. The private field prevents struct-literal
/// construction outside this module, so options can be added later without breaking
/// callers; use `WorkerConfig::default()`.
#[derive(Default, Clone, Debug)]
pub struct WorkerConfig {
    _private: (),
}
74
/// How a successful `Worker::run` / `Worker::resume` call ended.
///
/// The type is a plain tag, so it is `Copy` and comparable. Callers can write
/// `assert_eq!(result, WorkerResult::Finished)` instead of matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerResult {
    /// The model produced a turn without tool calls and every `OnTurnEnd` hook agreed to stop.
    Finished,
    /// A hook paused the run; call `resume` to continue from the stored history.
    Paused,
}
87
88enum ToolExecutionResult {
90 Completed(Vec<ToolResult>),
91 Paused,
92}
93
/// Receives turn-boundary notifications from the worker loop.
///
/// The trait is object-safe and `Send + Sync`, so heterogeneous notifiers can be stored
/// as `Box<dyn TurnNotifier>`.
trait TurnNotifier: Send + Sync {
    /// Called when turn number `turn` begins.
    fn on_turn_start(&self, turn: usize);

    /// Called when turn number `turn` ends.
    fn on_turn_end(&self, turn: usize);
}
103
104struct SubscriberTurnNotifier<S: WorkerSubscriber + 'static> {
105 subscriber: Arc<Mutex<S>>,
106}
107
108impl<S: WorkerSubscriber + 'static> TurnNotifier for SubscriberTurnNotifier<S> {
109 fn on_turn_start(&self, turn: usize) {
110 if let Ok(mut s) = self.subscriber.lock() {
111 s.on_turn_start(turn);
112 }
113 }
114
115 fn on_turn_end(&self, turn: usize) {
116 if let Ok(mut s) = self.subscriber.lock() {
117 s.on_turn_end(turn);
118 }
119 }
120}
121
122pub struct Worker<C: LlmClient, S: WorkerState = Mutable> {
161 client: C,
163 timeline: Timeline,
165 text_block_collector: TextBlockCollector,
167 tool_call_collector: ToolCallCollector,
169 tools: HashMap<String, (ToolMeta, Arc<dyn Tool>)>,
171 hooks: HookRegistry,
173 system_prompt: Option<String>,
175 history: Vec<Message>,
177 locked_prefix_len: usize,
179 turn_count: usize,
181 turn_notifiers: Vec<Box<dyn TurnNotifier>>,
183 request_config: RequestConfig,
185 last_run_interrupted: bool,
187 cancel_tx: mpsc::Sender<()>,
189 cancel_rx: mpsc::Receiver<()>,
190 _state: PhantomData<S>,
192}
193
194impl<C: LlmClient, S: WorkerState> Worker<C, S> {
199 fn reset_interruption_state(&mut self) {
200 self.last_run_interrupted = false;
201 }
202
203 pub async fn run(
208 &mut self,
209 user_input: impl Into<String>,
210 ) -> Result<WorkerResult, WorkerError> {
211 self.reset_interruption_state();
212 let mut user_message = Message::user(user_input);
214 let result = self.run_on_prompt_submit_hooks(&mut user_message).await;
215 let result = match result {
216 Ok(value) => value,
217 Err(err) => return self.finalize_interruption(Err(err)).await,
218 };
219 match result {
220 OnPromptSubmitResult::Cancel(reason) => {
221 self.last_run_interrupted = true;
222 return self.finalize_interruption(Err(WorkerError::Aborted(reason))).await;
223 }
224 OnPromptSubmitResult::Continue => {}
225 }
226 self.history.push(user_message);
227 let result = self.run_turn_loop().await;
228 self.finalize_interruption(result).await
229 }
230
231 fn drain_cancel_queue(&mut self) {
232 use tokio::sync::mpsc::error::TryRecvError;
233 loop {
234 match self.cancel_rx.try_recv() {
235 Ok(()) => continue,
236 Err(TryRecvError::Empty) | Err(TryRecvError::Disconnected) => break,
237 }
238 }
239 }
240
241 fn try_cancelled(&mut self) -> bool {
242 use tokio::sync::mpsc::error::TryRecvError;
243 match self.cancel_rx.try_recv() {
244 Ok(()) => true,
245 Err(TryRecvError::Empty) => false,
246 Err(TryRecvError::Disconnected) => true,
247 }
248 }
249
250 pub fn subscribe<Sub: WorkerSubscriber + 'static>(&mut self, subscriber: Sub) {
282 let subscriber = Arc::new(Mutex::new(subscriber));
283
284 self.timeline
286 .on_text_block(TextBlockSubscriberAdapter::new(subscriber.clone()));
287
288 self.timeline
290 .on_tool_use_block(ToolUseBlockSubscriberAdapter::new(subscriber.clone()));
291
292 self.timeline
294 .on_usage(UsageSubscriberAdapter::new(subscriber.clone()));
295 self.timeline
296 .on_status(StatusSubscriberAdapter::new(subscriber.clone()));
297 self.timeline
298 .on_error(ErrorSubscriberAdapter::new(subscriber.clone()));
299
300 self.turn_notifiers
302 .push(Box::new(SubscriberTurnNotifier { subscriber }));
303 }
304
305 pub fn register_tool(&mut self, factory: ToolDefinition) -> Result<(), ToolRegistryError> {
322 let (meta, instance) = factory();
323 if self.tools.contains_key(&meta.name) {
324 return Err(ToolRegistryError::DuplicateName(meta.name.clone()));
325 }
326 self.tools.insert(meta.name.clone(), (meta, instance));
327 Ok(())
328 }
329
330 pub fn register_tools(
332 &mut self,
333 factories: impl IntoIterator<Item = ToolDefinition>,
334 ) -> Result<(), ToolRegistryError> {
335 for factory in factories {
336 self.register_tool(factory)?;
337 }
338 Ok(())
339 }
340
341 pub fn add_on_prompt_submit_hook(&mut self, hook: impl Hook<OnPromptSubmit> + 'static) {
345 self.hooks.on_prompt_submit.push(Box::new(hook));
346 }
347
348 pub fn add_pre_llm_request_hook(&mut self, hook: impl Hook<PreLlmRequest> + 'static) {
352 self.hooks.pre_llm_request.push(Box::new(hook));
353 }
354
355 pub fn add_pre_tool_call_hook(&mut self, hook: impl Hook<PreToolCall> + 'static) {
357 self.hooks.pre_tool_call.push(Box::new(hook));
358 }
359
360 pub fn add_post_tool_call_hook(&mut self, hook: impl Hook<PostToolCall> + 'static) {
362 self.hooks.post_tool_call.push(Box::new(hook));
363 }
364
365 pub fn add_on_turn_end_hook(&mut self, hook: impl Hook<OnTurnEnd> + 'static) {
367 self.hooks.on_turn_end.push(Box::new(hook));
368 }
369
370 pub fn add_on_abort_hook(&mut self, hook: impl Hook<OnAbort> + 'static) {
372 self.hooks.on_abort.push(Box::new(hook));
373 }
374
375 pub fn timeline_mut(&mut self) -> &mut Timeline {
377 &mut self.timeline
378 }
379
380 pub fn history(&self) -> &[Message] {
382 &self.history
383 }
384
385 pub fn get_system_prompt(&self) -> Option<&str> {
387 self.system_prompt.as_deref()
388 }
389
390 pub fn turn_count(&self) -> usize {
392 self.turn_count
393 }
394
395 pub fn request_config(&self) -> &RequestConfig {
397 &self.request_config
398 }
399
400 pub fn set_max_tokens(&mut self, max_tokens: u32) {
410 self.request_config.max_tokens = Some(max_tokens);
411 }
412
413 pub fn set_temperature(&mut self, temperature: f32) {
424 self.request_config.temperature = Some(temperature);
425 }
426
427 pub fn set_top_p(&mut self, top_p: f32) {
435 self.request_config.top_p = Some(top_p);
436 }
437
438 pub fn set_top_k(&mut self, top_k: u32) {
448 self.request_config.top_k = Some(top_k);
449 }
450
451 pub fn add_stop_sequence(&mut self, sequence: impl Into<String>) {
459 self.request_config.stop_sequences.push(sequence.into());
460 }
461
462 pub fn clear_stop_sequences(&mut self) {
464 self.request_config.stop_sequences.clear();
465 }
466
467 pub fn cancel_sender(&self) -> mpsc::Sender<()> {
469 self.cancel_tx.clone()
470 }
471
472 pub fn set_request_config(&mut self, config: RequestConfig) {
474 self.request_config = config;
475 }
476
477 pub fn cancel(&self) {
499 let _ = self.cancel_tx.try_send(());
500 }
501
502 pub fn is_cancelled(&mut self) -> bool {
504 self.try_cancelled()
505 }
506
507 pub fn last_run_interrupted(&self) -> bool {
509 self.last_run_interrupted
510 }
511
512 fn build_tool_definitions(&self) -> Vec<LlmToolDefinition> {
514 self.tools
515 .values()
516 .map(|(meta, _)| {
517 LlmToolDefinition::new(&meta.name)
518 .description(&meta.description)
519 .input_schema(meta.input_schema.clone())
520 })
521 .collect()
522 }
523
524 fn build_assistant_message(
526 &self,
527 text_blocks: &[String],
528 tool_calls: &[ToolCall],
529 ) -> Option<Message> {
530 if text_blocks.is_empty() && tool_calls.is_empty() {
532 return None;
533 }
534
535 if tool_calls.is_empty() {
537 let text = text_blocks.join("");
538 return Some(Message::assistant(text));
539 }
540
541 let mut parts = Vec::new();
543
544 for text in text_blocks {
546 if !text.is_empty() {
547 parts.push(ContentPart::Text { text: text.clone() });
548 }
549 }
550
551 for call in tool_calls {
553 parts.push(ContentPart::ToolUse {
554 id: call.id.clone(),
555 name: call.name.clone(),
556 input: call.input.clone(),
557 });
558 }
559
560 Some(Message {
561 role: Role::Assistant,
562 content: MessageContent::Parts(parts),
563 })
564 }
565
566 fn build_request(
568 &self,
569 tool_definitions: &[LlmToolDefinition],
570 context: &[Message],
571 ) -> Request {
572 let mut request = Request::new();
573
574 if let Some(ref system) = self.system_prompt {
576 request = request.system(system);
577 }
578
579 for msg in context {
581 request = request.message(crate::llm_client::Message {
583 role: match msg.role {
584 Role::User => crate::llm_client::Role::User,
585 Role::Assistant => crate::llm_client::Role::Assistant,
586 },
587 content: match &msg.content {
588 MessageContent::Text(t) => crate::llm_client::MessageContent::Text(t.clone()),
589 MessageContent::ToolResult {
590 tool_use_id,
591 content,
592 } => crate::llm_client::MessageContent::ToolResult {
593 tool_use_id: tool_use_id.clone(),
594 content: content.clone(),
595 },
596 MessageContent::Parts(parts) => crate::llm_client::MessageContent::Parts(
597 parts
598 .iter()
599 .map(|p| match p {
600 ContentPart::Text { text } => {
601 crate::llm_client::ContentPart::Text { text: text.clone() }
602 }
603 ContentPart::ToolUse { id, name, input } => {
604 crate::llm_client::ContentPart::ToolUse {
605 id: id.clone(),
606 name: name.clone(),
607 input: input.clone(),
608 }
609 }
610 ContentPart::ToolResult {
611 tool_use_id,
612 content,
613 } => crate::llm_client::ContentPart::ToolResult {
614 tool_use_id: tool_use_id.clone(),
615 content: content.clone(),
616 },
617 })
618 .collect(),
619 ),
620 },
621 });
622 }
623
624 for tool_def in tool_definitions {
626 request = request.tool(tool_def.clone());
627 }
628
629 request = request.config(self.request_config.clone());
631
632 request
633 }
634
635 async fn run_on_prompt_submit_hooks(
639 &self,
640 message: &mut Message,
641 ) -> Result<OnPromptSubmitResult, WorkerError> {
642 for hook in &self.hooks.on_prompt_submit {
643 let result = hook.call(message).await?;
644 match result {
645 OnPromptSubmitResult::Continue => continue,
646 OnPromptSubmitResult::Cancel(reason) => {
647 return Ok(OnPromptSubmitResult::Cancel(reason));
648 }
649 }
650 }
651 Ok(OnPromptSubmitResult::Continue)
652 }
653
654 async fn run_pre_llm_request_hooks(
658 &self,
659 ) -> Result<(PreLlmRequestResult, Vec<Message>), WorkerError> {
660 let mut temp_context = self.history.clone();
661 for hook in &self.hooks.pre_llm_request {
662 let result = hook.call(&mut temp_context).await?;
663 match result {
664 PreLlmRequestResult::Continue => continue,
665 PreLlmRequestResult::Cancel(reason) => {
666 return Ok((PreLlmRequestResult::Cancel(reason), temp_context));
667 }
668 }
669 }
670 Ok((PreLlmRequestResult::Continue, temp_context))
671 }
672
673 async fn run_on_turn_end_hooks(&self) -> Result<OnTurnEndResult, WorkerError> {
675 let mut temp_messages = self.history.clone();
676 for hook in &self.hooks.on_turn_end {
677 let result = hook.call(&mut temp_messages).await?;
678 match result {
679 OnTurnEndResult::Finish => continue,
680 OnTurnEndResult::ContinueWithMessages(msgs) => {
681 return Ok(OnTurnEndResult::ContinueWithMessages(msgs));
682 }
683 OnTurnEndResult::Paused => return Ok(OnTurnEndResult::Paused),
684 }
685 }
686 Ok(OnTurnEndResult::Finish)
687 }
688
689 async fn run_on_abort_hooks(&self, reason: &str) -> Result<(), WorkerError> {
691 let mut reason = reason.to_string();
692 for hook in &self.hooks.on_abort {
693 hook.call(&mut reason).await?;
694 }
695 Ok(())
696 }
697
698 async fn finalize_interruption<T>(
699 &mut self,
700 result: Result<T, WorkerError>,
701 ) -> Result<T, WorkerError> {
702 match result {
703 Ok(value) => Ok(value),
704 Err(err) => {
705 self.last_run_interrupted = true;
706 let reason = match &err {
707 WorkerError::Aborted(reason) => reason.clone(),
708 WorkerError::Cancelled => "Cancelled".to_string(),
709 _ => err.to_string(),
710 };
711 if let Err(hook_err) = self.run_on_abort_hooks(&reason).await {
712 self.last_run_interrupted = true;
713 return Err(hook_err);
714 }
715 Err(err)
716 }
717 }
718 }
719
720 fn get_pending_tool_calls(&self) -> Option<Vec<ToolCall>> {
722 let last_msg = self.history.last()?;
723 if last_msg.role != Role::Assistant {
724 return None;
725 }
726
727 let mut calls = Vec::new();
728 if let MessageContent::Parts(parts) = &last_msg.content {
729 for part in parts {
730 if let ContentPart::ToolUse { id, name, input } = part {
731 calls.push(ToolCall {
732 id: id.clone(),
733 name: name.clone(),
734 input: input.clone(),
735 });
736 }
737 }
738 }
739
740 if calls.is_empty() { None } else { Some(calls) }
741 }
742
743 async fn execute_tools(
748 &mut self,
749 tool_calls: Vec<ToolCall>,
750 ) -> Result<ToolExecutionResult, WorkerError> {
751 use futures::future::join_all;
752
753 let mut call_info_map = HashMap::new();
756
757 let mut approved_calls = Vec::new();
759 for mut tool_call in tool_calls {
760 if let Some((meta, tool)) = self.tools.get(&tool_call.name) {
762 let mut context = ToolCallContext {
764 call: tool_call.clone(),
765 meta: meta.clone(),
766 tool: tool.clone(),
767 };
768
769 let mut skip = false;
770 for hook in &self.hooks.pre_tool_call {
771 let result = hook
772 .call(&mut context)
773 .await
774 .inspect_err(|_| self.last_run_interrupted = true)?;
775 match result {
776 PreToolCallResult::Continue => {}
777 PreToolCallResult::Skip => {
778 skip = true;
779 break;
780 }
781 PreToolCallResult::Abort(reason) => {
782 self.last_run_interrupted = true;
783 return Err(WorkerError::Aborted(reason));
784 }
785 PreToolCallResult::Pause => {
786 self.last_run_interrupted = true;
787 return Ok(ToolExecutionResult::Paused);
788 }
789 }
790 }
791
792 tool_call = context.call;
794
795 if !skip {
797 call_info_map.insert(
798 tool_call.id.clone(),
799 (tool_call.clone(), meta.clone(), tool.clone()),
800 );
801 approved_calls.push(tool_call);
802 }
803 } else {
804 approved_calls.push(tool_call);
807 }
808 }
809
810 let futures: Vec<_> = approved_calls
812 .into_iter()
813 .map(|tool_call| {
814 let tools = &self.tools;
815 async move {
816 if let Some((_, tool)) = tools.get(&tool_call.name) {
817 let input_json =
818 serde_json::to_string(&tool_call.input).unwrap_or_default();
819 match tool.execute(&input_json).await {
820 Ok(content) => ToolResult::success(&tool_call.id, content),
821 Err(e) => ToolResult::error(&tool_call.id, e.to_string()),
822 }
823 } else {
824 ToolResult::error(
825 &tool_call.id,
826 format!("Tool '{}' not found", tool_call.name),
827 )
828 }
829 }
830 })
831 .collect();
832
833 let mut results = tokio::select! {
835 results = join_all(futures) => results,
836 cancel = self.cancel_rx.recv() => {
837 if cancel.is_some() {
838 info!("Tool execution cancelled");
839 }
840 self.timeline.abort_current_block();
841 self.last_run_interrupted = true;
842 return Err(WorkerError::Cancelled);
843 }
844 };
845
846 for tool_result in &mut results {
848 if let Some((tool_call, meta, tool)) = call_info_map.get(&tool_result.tool_use_id) {
850 let mut context = PostToolCallContext {
851 call: tool_call.clone(),
852 result: tool_result.clone(),
853 meta: meta.clone(),
854 tool: tool.clone(),
855 };
856
857 for hook in &self.hooks.post_tool_call {
858 let result = hook
859 .call(&mut context)
860 .await
861 .inspect_err(|_| self.last_run_interrupted = true)?;
862 match result {
863 PostToolCallResult::Continue => {}
864 PostToolCallResult::Abort(reason) => {
865 self.last_run_interrupted = true;
866 return Err(WorkerError::Aborted(reason));
867 }
868 }
869 }
870 *tool_result = context.result;
872 }
873 }
874
875 Ok(ToolExecutionResult::Completed(results))
876 }
877
878 async fn run_turn_loop(&mut self) -> Result<WorkerResult, WorkerError> {
880 self.reset_interruption_state();
881 self.drain_cancel_queue();
882 let tool_definitions = self.build_tool_definitions();
883
884 info!(
885 message_count = self.history.len(),
886 tool_count = tool_definitions.len(),
887 "Starting worker run"
888 );
889
890 if let Some(tool_calls) = self.get_pending_tool_calls() {
892 info!("Resuming pending tool calls");
893 match self.execute_tools(tool_calls).await {
894 Ok(ToolExecutionResult::Paused) => {
895 self.last_run_interrupted = true;
896 return Ok(WorkerResult::Paused);
897 }
898 Ok(ToolExecutionResult::Completed(results)) => {
899 for result in results {
900 self.history
901 .push(Message::tool_result(&result.tool_use_id, &result.content));
902 }
903 }
905 Err(err) => {
906 self.last_run_interrupted = true;
907 return Err(err);
908 }
909 }
910 }
911
912 loop {
913 if self.try_cancelled() {
915 info!("Execution cancelled");
916 self.timeline.abort_current_block();
917 self.last_run_interrupted = true;
918 return Err(WorkerError::Cancelled);
919 }
920
921 let current_turn = self.turn_count;
923 debug!(turn = current_turn, "Turn start");
924 for notifier in &self.turn_notifiers {
925 notifier.on_turn_start(current_turn);
926 }
927
928 let (control, request_context) = self
930 .run_pre_llm_request_hooks()
931 .await
932 .inspect_err(|_| self.last_run_interrupted = true)?;
933 match control {
934 PreLlmRequestResult::Cancel(reason) => {
935 info!(reason = %reason, "Aborted by hook");
936 for notifier in &self.turn_notifiers {
937 notifier.on_turn_end(current_turn);
938 }
939 self.last_run_interrupted = true;
940 return Err(WorkerError::Aborted(reason));
941 }
942 PreLlmRequestResult::Continue => {}
943 }
944
945 let request = self.build_request(&tool_definitions, &request_context);
947 debug!(
948 message_count = request.messages.len(),
949 tool_count = request.tools.len(),
950 has_system = request.system_prompt.is_some(),
951 "Sending request to LLM"
952 );
953
954 debug!("Starting stream...");
956 let mut event_count = 0;
957
958 let mut stream = tokio::select! {
960 stream_result = self.client.stream(request) => stream_result
961 .inspect_err(|_| self.last_run_interrupted = true)?,
962 cancel = self.cancel_rx.recv() => {
963 if cancel.is_some() {
964 info!("Cancelled before stream started");
965 }
966 self.timeline.abort_current_block();
967 self.last_run_interrupted = true;
968 return Err(WorkerError::Cancelled);
969 }
970 };
971
972 loop {
973 tokio::select! {
974 event_result = stream.next() => {
976 match event_result {
977 Some(result) => {
978 match &result {
979 Ok(event) => {
980 trace!(event = ?event, "Received event");
981 event_count += 1;
982 }
983 Err(e) => {
984 warn!(error = %e, "Stream error");
985 }
986 }
987 let event = result
988 .inspect_err(|_| self.last_run_interrupted = true)?;
989 let timeline_event: crate::timeline::event::Event = event.into();
990 self.timeline.dispatch(&timeline_event);
991 }
992 None => break, }
994 }
995 cancel = self.cancel_rx.recv() => {
997 if cancel.is_some() {
998 info!("Stream cancelled");
999 }
1000 self.timeline.abort_current_block();
1001 self.last_run_interrupted = true;
1002 return Err(WorkerError::Cancelled);
1003 }
1004 }
1005 }
1006 debug!(event_count = event_count, "Stream completed");
1007
1008 for notifier in &self.turn_notifiers {
1010 notifier.on_turn_end(current_turn);
1011 }
1012 self.turn_count += 1;
1013
1014 let text_blocks = self.text_block_collector.take_collected();
1016 let tool_calls = self.tool_call_collector.take_collected();
1017
1018 let assistant_message = self.build_assistant_message(&text_blocks, &tool_calls);
1020 if let Some(msg) = assistant_message {
1021 self.history.push(msg);
1022 }
1023
1024 if tool_calls.is_empty() {
1025 let turn_result = self
1027 .run_on_turn_end_hooks()
1028 .await
1029 .inspect_err(|_| self.last_run_interrupted = true)?;
1030 match turn_result {
1031 OnTurnEndResult::Finish => {
1032 self.last_run_interrupted = false;
1033 return Ok(WorkerResult::Finished);
1034 }
1035 OnTurnEndResult::ContinueWithMessages(additional) => {
1036 self.history.extend(additional);
1037 continue;
1038 }
1039 OnTurnEndResult::Paused => {
1040 self.last_run_interrupted = true;
1041 return Ok(WorkerResult::Paused);
1042 }
1043 }
1044 }
1045
1046 match self.execute_tools(tool_calls).await {
1048 Ok(ToolExecutionResult::Paused) => {
1049 self.last_run_interrupted = true;
1050 return Ok(WorkerResult::Paused);
1051 }
1052 Ok(ToolExecutionResult::Completed(results)) => {
1053 for result in results {
1054 self.history
1055 .push(Message::tool_result(&result.tool_use_id, &result.content));
1056 }
1057 }
1058 Err(err) => {
1059 self.last_run_interrupted = true;
1060 return Err(err);
1061 }
1062 }
1063 }
1064 }
1065
1066 pub async fn resume(&mut self) -> Result<WorkerResult, WorkerError> {
1070 self.reset_interruption_state();
1071 let result = self.run_turn_loop().await;
1072 self.finalize_interruption(result).await
1073 }
1074}
1075
1076impl<C: LlmClient> Worker<C, Mutable> {
1081 pub fn new(client: C) -> Self {
1083 let text_block_collector = TextBlockCollector::new();
1084 let tool_call_collector = ToolCallCollector::new();
1085 let mut timeline = Timeline::new();
1086 let (cancel_tx, cancel_rx) = mpsc::channel(1);
1087
1088 timeline.on_text_block(text_block_collector.clone());
1090 timeline.on_tool_use_block(tool_call_collector.clone());
1091
1092 Self {
1093 client,
1094 timeline,
1095 text_block_collector,
1096 tool_call_collector,
1097 tools: HashMap::new(),
1098 hooks: HookRegistry::new(),
1099 system_prompt: None,
1100 history: Vec::new(),
1101 locked_prefix_len: 0,
1102 turn_count: 0,
1103 turn_notifiers: Vec::new(),
1104 request_config: RequestConfig::default(),
1105 last_run_interrupted: false,
1106 cancel_tx,
1107 cancel_rx,
1108 _state: PhantomData,
1109 }
1110 }
1111
1112 pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
1114 self.system_prompt = Some(prompt.into());
1115 self
1116 }
1117
1118 pub fn set_system_prompt(&mut self, prompt: impl Into<String>) {
1120 self.system_prompt = Some(prompt.into());
1121 }
1122
1123 pub fn max_tokens(mut self, max_tokens: u32) -> Self {
1133 self.request_config.max_tokens = Some(max_tokens);
1134 self
1135 }
1136
1137 pub fn temperature(mut self, temperature: f32) -> Self {
1146 self.request_config.temperature = Some(temperature);
1147 self
1148 }
1149
1150 pub fn top_p(mut self, top_p: f32) -> Self {
1152 self.request_config.top_p = Some(top_p);
1153 self
1154 }
1155
1156 pub fn top_k(mut self, top_k: u32) -> Self {
1158 self.request_config.top_k = Some(top_k);
1159 self
1160 }
1161
1162 pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
1164 self.request_config.stop_sequences.push(sequence.into());
1165 self
1166 }
1167
1168 pub fn with_config(mut self, config: RequestConfig) -> Self {
1182 self.request_config = config;
1183 self
1184 }
1185
1186 pub fn validate(self) -> Result<Self, WorkerError> {
1204 let warnings = self.client.validate_config(&self.request_config);
1205 if warnings.is_empty() {
1206 Ok(self)
1207 } else {
1208 Err(WorkerError::ConfigWarnings(warnings))
1209 }
1210 }
1211
1212 pub fn history_mut(&mut self) -> &mut Vec<Message> {
1216 &mut self.history
1217 }
1218
1219 pub fn set_history(&mut self, messages: Vec<Message>) {
1221 self.history = messages;
1222 }
1223
1224 pub fn with_message(mut self, message: Message) -> Self {
1226 self.history.push(message);
1227 self
1228 }
1229
1230 pub fn push_message(&mut self, message: Message) {
1232 self.history.push(message);
1233 }
1234
1235 pub fn with_messages(mut self, messages: impl IntoIterator<Item = Message>) -> Self {
1237 self.history.extend(messages);
1238 self
1239 }
1240
1241 pub fn extend_history(&mut self, messages: impl IntoIterator<Item = Message>) {
1243 self.history.extend(messages);
1244 }
1245
1246 pub fn clear_history(&mut self) {
1248 self.history.clear();
1249 }
1250
1251 #[allow(dead_code)]
1253 pub fn config(self, _config: WorkerConfig) -> Self {
1254 self
1255 }
1256
1257 pub fn lock(self) -> Worker<C, CacheLocked> {
1262 let locked_prefix_len = self.history.len();
1263 Worker {
1264 client: self.client,
1265 timeline: self.timeline,
1266 text_block_collector: self.text_block_collector,
1267 tool_call_collector: self.tool_call_collector,
1268 tools: self.tools,
1269 hooks: self.hooks,
1270 system_prompt: self.system_prompt,
1271 history: self.history,
1272 locked_prefix_len,
1273 turn_count: self.turn_count,
1274 turn_notifiers: self.turn_notifiers,
1275 request_config: self.request_config,
1276 last_run_interrupted: self.last_run_interrupted,
1277 cancel_tx: self.cancel_tx,
1278 cancel_rx: self.cancel_rx,
1279 _state: PhantomData,
1280 }
1281 }
1282
1283}
1284
1285impl<C: LlmClient> Worker<C, CacheLocked> {
1290 pub fn locked_prefix_len(&self) -> usize {
1292 self.locked_prefix_len
1293 }
1294
1295 pub fn unlock(self) -> Worker<C, Mutable> {
1300 Worker {
1301 client: self.client,
1302 timeline: self.timeline,
1303 text_block_collector: self.text_block_collector,
1304 tool_call_collector: self.tool_call_collector,
1305 tools: self.tools,
1306 hooks: self.hooks,
1307 system_prompt: self.system_prompt,
1308 history: self.history,
1309 locked_prefix_len: 0,
1310 turn_count: self.turn_count,
1311 turn_notifiers: self.turn_notifiers,
1312 request_config: self.request_config,
1313 last_run_interrupted: self.last_run_interrupted,
1314 cancel_tx: self.cancel_tx,
1315 cancel_rx: self.cancel_rx,
1316 _state: PhantomData,
1317 }
1318 }
1319}
1320
// Unit tests for the worker.
#[cfg(test)]
mod tests {
    // NOTE(review): this module is currently empty; no worker tests are defined here.
}