use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::{Receiver, Sender};
use tokio_util::sync::CancellationToken;
use crate::ai::ai_state::{AiRequest, AiResponse, AiState};
impl AiState {
pub fn append_chunk(&mut self, chunk: &str) {
self.response.push_str(chunk);
}
pub fn send_request(&mut self, prompt: String) -> bool {
if self.request_tx.is_none() {
return false;
}
self.cancel_in_flight_request();
self.start_request();
let request_id = self.request_id;
let cancel_token = CancellationToken::new();
self.current_cancel_token = Some(cancel_token.clone());
if let Some(ref tx) = self.request_tx
&& tx
.send(AiRequest::Query {
prompt,
request_id,
cancel_token,
})
.is_ok()
{
return true;
}
self.current_cancel_token = None;
false
}
pub fn set_channels(
&mut self,
request_tx: Sender<AiRequest>,
response_rx: Receiver<AiResponse>,
) {
self.request_tx = Some(request_tx);
self.response_rx = Some(response_rx);
}
pub fn current_request_id(&self) -> u64 {
self.request_id
}
fn compute_query_hash(query: &str) -> u64 {
let mut hasher = DefaultHasher::new();
query.hash(&mut hasher);
hasher.finish()
}
pub fn is_query_changed(&self, query: &str) -> bool {
let query_hash = Self::compute_query_hash(query);
match self.last_query_hash {
None => true,
Some(last_hash) => query_hash != last_hash,
}
}
pub fn set_last_query_hash(&mut self, query: &str) {
self.last_query_hash = Some(Self::compute_query_hash(query));
}
pub fn cancel_in_flight_request(&mut self) -> bool {
if let Some(token) = self.current_cancel_token.take() {
token.cancel();
self.in_flight_request_id = None;
return true;
}
false
}
#[cfg(test)]
pub fn has_in_flight_request(&self) -> bool {
self.in_flight_request_id.is_some()
}
}