use std::panic::{self, AssertUnwindSafe};
use std::sync::mpsc::{Receiver, Sender};
use tokio_util::sync::CancellationToken;
use super::ai_state::{AiRequest, AiResponse};
use super::provider::{AiError, AsyncAiProvider};
use crate::config::ai_types::AiConfig;
pub fn spawn_worker(
config: &AiConfig,
request_rx: Receiver<AiRequest>,
response_tx: Sender<AiResponse>,
) {
let provider_result = AsyncAiProvider::from_config(config);
std::thread::spawn(move || {
let response_tx_clone = response_tx.clone();
let prev_hook = panic::take_hook();
panic::set_hook(Box::new(move |panic_info| {
let panic_msg = if let Some(s) = panic_info.payload().downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = panic_info.payload().downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic in AI worker".to_string()
};
log::error!(
"AI worker panic: {} at {:?}",
panic_msg,
panic_info.location()
);
let _ = response_tx_clone.send(AiResponse::Error(format!(
"AI worker crashed: {}",
panic_msg
)));
}));
let result = panic::catch_unwind(AssertUnwindSafe(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create tokio runtime");
rt.block_on(worker_loop(provider_result, request_rx, response_tx));
}));
panic::set_hook(prev_hook);
if let Err(e) = result {
let panic_msg = if let Some(s) = e.downcast_ref::<&str>() {
s.to_string()
} else if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else {
"Unknown panic".to_string()
};
log::error!("AI worker thread panicked: {}", panic_msg);
}
});
}
/// Drives the worker: pulls requests off the channel until every sender is
/// dropped, dispatching each one to the async query handler.
///
/// A failed provider construction is downgraded to `None`; the per-request
/// handler reports the "not configured" error to the caller in that case.
async fn worker_loop(
    provider_result: Result<AsyncAiProvider, AiError>,
    request_rx: Receiver<AiRequest>,
    response_tx: Sender<AiResponse>,
) {
    let provider = provider_result.ok();
    // Receiver's IntoIterator blocks on each recv and ends once the channel
    // is closed — equivalent to looping on `recv()` until it errors.
    for request in request_rx {
        let AiRequest::Query {
            prompt,
            request_id,
            cancel_token,
        } = request;
        handle_query_async(&provider, &prompt, request_id, cancel_token, &response_tx).await;
    }
}
async fn handle_query_async(
provider: &Option<AsyncAiProvider>,
prompt: &str,
request_id: u64,
cancel_token: CancellationToken,
response_tx: &Sender<AiResponse>,
) {
if cancel_token.is_cancelled() {
let _ = response_tx.send(AiResponse::Cancelled { request_id });
return;
}
let provider = match provider {
Some(p) => p,
None => {
let _ = response_tx.send(AiResponse::Error(
"AI not configured. Enable AI in your config file with 'enabled = true' and configure a provider. See https://github.com/bellicose100xp/jiq#configuration for setup instructions.".to_string(),
));
return;
}
};
match provider
.stream_with_cancel(prompt, request_id, cancel_token, response_tx.clone())
.await
{
Ok(()) => {
let _ = response_tx.send(AiResponse::Complete { request_id });
}
Err(AiError::Cancelled) => {
let _ = response_tx.send(AiResponse::Cancelled { request_id });
}
Err(e) => {
let _ = response_tx.send(AiResponse::Error(e.to_string()));
}
}
}
// Unit tests live in a sibling file; `#[path]` points the module there.
#[cfg(test)]
#[path = "worker_tests.rs"]
mod worker_tests;