use std::{
collections::HashMap,
ops::{ControlFlow, Deref},
sync::Arc,
};
use tokio::{
process::{Child, ChildStderr},
sync::{Notify, RwLock, RwLockReadGuard},
};
use tokio_stream::StreamExt;
use crate::{
engine::{Engine, EngineStdin, EngineStdout},
*,
};
/// Asynchronous client for an analysis engine child process.
///
/// Build one with `Analyzer::from(engine)`; that conversion also spawns the background task
/// that reads the engine's stdout and completes pending requests. Every request is tagged
/// with a fresh id from an internal counter, and its reply channel is kept in
/// `pending_requests` until the response task delivers a result.
///
/// `W` is the `WarningHandling` policy used for the results this analyzer produces; it
/// defaults to `WarningsAsErrors`.
pub struct Analyzer<W: WarningHandling = WarningsAsErrors> {
    /// Request channel to the engine; every request method writes to it.
    stdin: EngineStdin,
    /// The engine's stderr stream, if it was captured. `Analyzer` itself never reads it.
    pub stderr: Option<ChildStderr>,
    /// Handle to the engine child process.
    pub child_process: Child,
    /// Next request id to hand out; ids are this counter rendered as a decimal string.
    next_id: u32,
    /// Requests still waiting for a reply; shared with the response task.
    pending_requests: Arc<RwLock<PendingRequests<W>>>,
}
impl<W: WarningHandling> Analyzer<W> {
pub async fn analyze(
&mut self,
request: AnalysisRequest,
) -> WarningResult<Option<AnalysisResult>, W> {
self.start_analyze(request).await?.finish().await
}
pub async fn analyze_position(
&mut self,
request: AnalysisRequest,
position: usize,
) -> WarningResult<Option<AnalysisResult>, W> {
self.start_analyze_position(request, position)
.await?
.finish()
.await
}
pub async fn analyze_game(
&mut self,
request: AnalysisRequest,
) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
self.start_analyze_game(request).await?.finish().await
}
pub async fn analyze_positions(
&mut self,
request: AnalysisRequest,
analyze_turns: Vec<usize>,
) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
self.start_analyze_positions(request, analyze_turns)
.await?
.finish()
.await
}
pub async fn start_analyze(&mut self, request: AnalysisRequest) -> Result<AnalysisProgress<W>> {
let position = request.moves.len();
self.start_analyze_position(request, position).await
}
pub async fn start_analyze_position(
&mut self,
request: AnalysisRequest,
position: usize,
) -> Result<AnalysisProgress<W>> {
Ok(self
.start_analyze_positions(request, vec![position])
.await?
.into_positions()
.remove(&position)
.expect("position analysis should be available"))
}
pub async fn start_analyze_game(
&mut self,
request: AnalysisRequest,
) -> Result<GameAnalysisProgress<W>> {
let positions = (0..=request.moves.len()).collect();
self.start_analyze_positions(request, positions).await
}
pub async fn start_analyze_positions(
&mut self,
request: AnalysisRequest,
analyze_turns: Vec<usize>,
) -> Result<GameAnalysisProgress<W>> {
self.start_analyze_positions_impl(request, analyze_turns, None)
.await
}
pub async fn analyze_game_prioritized(
&mut self,
request: AnalysisRequest,
priorities: Vec<i32>,
) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
self.start_analyze_game_prioritized(request, priorities)
.await?
.finish()
.await
}
pub async fn analyze_positions_prioritized(
&mut self,
request: AnalysisRequest,
analyze_turns: Vec<usize>,
priorities: Vec<i32>,
) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
self.start_analyze_positions_prioritized(request, analyze_turns, priorities)
.await?
.finish()
.await
}
pub async fn start_analyze_game_prioritized(
&mut self,
request: AnalysisRequest,
priorities: Vec<i32>,
) -> Result<GameAnalysisProgress<W>> {
let positions = (0..=request.moves.len()).collect();
self.start_analyze_positions_prioritized(request, positions, priorities)
.await
}
pub async fn start_analyze_positions_prioritized(
&mut self,
request: AnalysisRequest,
analyze_turns: Vec<usize>,
priorities: Vec<i32>,
) -> Result<GameAnalysisProgress<W>> {
self.start_analyze_positions_impl(request, analyze_turns, Some(priorities))
.await
}
async fn start_analyze_positions_impl(
&mut self,
request: AnalysisRequest,
analyze_turns: Vec<usize>,
priorities: Option<Vec<i32>>,
) -> Result<GameAnalysisProgress<W>> {
let id = self.generate_id();
let mut senders = HashMap::new();
let mut positions = HashMap::new();
for position in &analyze_turns {
let (sender, receiver) = channel(W::ok(None));
senders.insert(*position, sender);
positions.insert(
*position,
AnalysisProgress::<W> {
receiver,
id: id.clone(),
turn_number: *position,
},
);
}
let pending_request = PendingRequest::<W> {
positions: senders,
width: request.board_x_size,
height: request.board_y_size,
};
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::Analyze(request.into_engine_request(
id.clone(),
analyze_turns,
priorities,
)))
.await?;
pending.requests.insert(id.clone(), pending_request);
Ok(GameAnalysisProgress::<W> { id, positions })
}
pub async fn query_version(&mut self) -> WarningResult<VersionInfo, W> {
let id = self.generate_id();
let (sender, receiver) = channel(W::ok(VersionInfo {
version: String::new(),
git_hash: String::new(),
}));
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::QueryVersion { id: id.clone() })
.await?;
pending.query_version_requests.insert(id, sender);
drop(pending);
receiver.finish().await
}
pub async fn clear_cache(&mut self) -> WarningResult<(), W> {
let id = self.generate_id();
let (sender, receiver) = channel(W::ok(()));
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::ClearCache { id: id.clone() })
.await?;
pending.clear_cache_requests.insert(id, sender);
drop(pending);
receiver.finish().await
}
pub async fn terminate(&mut self, progress: &AnalysisProgress) -> WarningResult<(), W> {
self.terminate_impl(progress.id.clone(), Some(vec![progress.turn_number]))
.await
}
pub async fn terminate_game(
&mut self,
progress: &GameAnalysisProgress,
) -> WarningResult<(), W> {
self.terminate_impl(progress.id.clone(), None).await
}
pub async fn terminate_positions(
&mut self,
progress: &GameAnalysisProgress,
turn_numbers: Vec<usize>,
) -> WarningResult<(), W> {
self.terminate_impl(progress.id.clone(), Some(turn_numbers))
.await
}
async fn terminate_impl(
&mut self,
terminate_id: String,
turn_numbers: Option<Vec<usize>>,
) -> WarningResult<(), W> {
let id = self.generate_id();
let (sender, receiver) = channel(W::ok(()));
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::Terminate {
id: id.clone(),
terminate_id,
turn_numbers,
})
.await?;
pending.terminate_requests.insert(id, sender);
drop(pending);
receiver.finish().await
}
pub async fn terminate_all(&mut self) -> WarningResult<(), W> {
self.terminate_all_impl(None).await
}
pub async fn terminate_all_positions(
&mut self,
turn_numbers: Vec<usize>,
) -> WarningResult<(), W> {
self.terminate_all_impl(Some(turn_numbers)).await
}
async fn terminate_all_impl(
&mut self,
turn_numbers: Option<Vec<usize>>,
) -> WarningResult<(), W> {
let id = self.generate_id();
let (sender, receiver) = channel(W::ok(()));
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::TerminateAll {
id: id.clone(),
turn_numbers,
})
.await?;
pending.terminate_all_requests.insert(id, sender);
drop(pending);
receiver.finish().await
}
pub async fn query_models(&mut self) -> WarningResult<Vec<Model>, W> {
let id = self.generate_id();
let (sender, receiver) = channel(W::ok(vec![]));
let mut pending = self.pending_requests.write().await;
self.stdin
.send(&engine::Request::QueryModels { id: id.clone() })
.await?;
pending.query_models_requests.insert(id, sender);
drop(pending);
receiver.finish().await
}
fn generate_id(&mut self) -> String {
let id = self.next_id.to_string();
self.next_id += 1;
id
}
}
impl<W: WarningHandling + Default + Clone + 'static> From<Engine> for Analyzer<W>
where
W::OkType<Option<AnalysisResult>>: Send + Sync,
W::OkType<VersionInfo>: Send + Sync,
W::OkType<()>: Send + Sync,
W::OkType<Vec<Model>>: Send + Sync,
{
fn from(engine: Engine) -> Self {
let client = Self {
stdin: engine.stdin,
stderr: engine.stderr,
child_process: engine.child_process,
next_id: 1,
pending_requests: Arc::default(),
};
tokio::spawn(handle_responses(
engine.stdout,
client.pending_requests.clone(),
));
client
}
}
/// Background task that reads engine responses from `stdout` and delivers each one to the
/// pending request registered under the same id.
///
/// * A stream item that is an `Err` fails every pending request with that error; the loop
///   then keeps reading.
/// * `GeneralError` fails every pending request, `FieldError` fails only the request with
///   the matching id, and `FieldWarning` attaches a warning to that request.
/// * When `stdout` ends, every remaining pending request is dropped. Dropping a sender wakes
///   its receiver, which then returns whatever value its channel currently holds (for a
///   request that never got a reply, its initial placeholder).
///
/// NOTE(review): the `Analyze` branch and the error/warning branches still await a
/// per-request value lock while holding the registry write lock. A caller that keeps an
/// `AnalysisProgress::read` guard alive while awaiting an `Analyzer` method (which takes
/// the registry write lock) could stall both sides; confirm whether that usage must be
/// supported.
async fn handle_responses<W: WarningHandling>(
    mut stdout: EngineStdout,
    pending: Arc<RwLock<PendingRequests<W>>>,
) {
    while let Some(response) = stdout.next().await {
        let response = match response {
            Ok(response) => response,
            Err(e) => {
                pending.write().await.poison_all(e).await;
                continue;
            }
        };
        match response {
            engine::Response::Analyze(response) => {
                let id = response.id.clone();
                let turn_number = response.turn_number;
                let is_during_search = response.is_during_search;
                let mut pending = pending.write().await;
                if let Some(request) = pending.requests.get_mut(&id) {
                    if let Some(sender) = request.positions.get(&turn_number) {
                        let result = Some(AnalysisResult::from_engine_response(
                            response,
                            request.width,
                            request.height,
                        ));
                        sender.send_modify(|r| W::set_result(r, result)).await;
                        // A final (non-search) update completes the turn: dropping the sender
                        // lets the receiver take the value.
                        if !is_during_search {
                            request.positions.remove(&turn_number);
                        }
                    }
                    if request.positions.is_empty() {
                        pending.requests.remove(&id);
                    }
                }
            }
            engine::Response::NoResults { id, turn_number } => {
                let mut pending = pending.write().await;
                if let Some(request) = pending.requests.get_mut(&id) {
                    request.positions.remove(&turn_number);
                    if request.positions.is_empty() {
                        pending.requests.remove(&id);
                    }
                }
            }
            // One-shot replies: take the sender out first. Temporaries in an `if let`
            // scrutinee live until the end of its body, which would keep the registry write
            // lock held across the `send_modify` await below.
            engine::Response::QueryVersion {
                id,
                version,
                git_hash,
            } => {
                let sender = pending.write().await.query_version_requests.remove(&id);
                if let Some(sender) = sender {
                    sender
                        .send_modify(|r| W::set_result(r, VersionInfo { version, git_hash }))
                        .await;
                }
            }
            engine::Response::ClearCache { id } => {
                let sender = pending.write().await.clear_cache_requests.remove(&id);
                if let Some(sender) = sender {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::Terminate { id, .. } => {
                let sender = pending.write().await.terminate_requests.remove(&id);
                if let Some(sender) = sender {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::TerminateAll { id, .. } => {
                let sender = pending.write().await.terminate_all_requests.remove(&id);
                if let Some(sender) = sender {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::QueryModels { id, models } => {
                let sender = pending.write().await.query_models_requests.remove(&id);
                if let Some(sender) = sender {
                    sender.send_modify(|r| W::set_result(r, models)).await;
                }
            }
            engine::Response::GeneralError { error } => {
                pending
                    .write()
                    .await
                    .poison_all(Error::KataGoGeneralError { error })
                    .await;
            }
            engine::Response::FieldError { id, error, field } => {
                pending
                    .write()
                    .await
                    .poison(&id, Error::KataGoFieldError { error, field })
                    .await;
            }
            engine::Response::FieldWarning { id, warning, field } => {
                pending
                    .write()
                    .await
                    .add_warning(&id, Warning { warning, field })
                    .await;
            }
        }
    }
    // The stream has ended: drop every remaining sender so no receiver waits forever.
    let mut pending = pending.write().await;
    pending.requests.clear();
    pending.query_version_requests.clear();
    pending.clear_cache_requests.clear();
    pending.terminate_requests.clear();
    pending.terminate_all_requests.clear();
    pending.query_models_requests.clear();
}
impl<W: WarningHandling> std::fmt::Debug for Analyzer<W>
where
    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
    W::OkType<VersionInfo>: std::fmt::Debug,
    W::OkType<()>: std::fmt::Debug,
    W::OkType<Vec<Model>>: std::fmt::Debug,
{
    /// Formats every field. Exhaustive destructuring makes a newly added field a compile
    /// error here until it is listed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            stdin,
            stderr,
            child_process,
            next_id,
            pending_requests,
        } = self;
        let mut builder = f.debug_struct("Analyzer");
        builder.field("stdin", stdin);
        builder.field("stderr", stderr);
        builder.field("child_process", child_process);
        builder.field("next_id", next_id);
        builder.field("pending_requests", pending_requests);
        builder.finish()
    }
}
/// Registry of requests that are still waiting for an engine reply, keyed by request id.
///
/// Each map holds the sending half of the channel whose receiver the caller is awaiting.
/// Removing or clearing an entry drops that sender, which lets the receiver finish.
#[derive(Default)]
struct PendingRequests<W: WarningHandling = WarningsAsErrors> {
    /// Analysis requests, each with one sender per requested turn.
    requests: HashMap<String, PendingRequest<W>>,
    /// Outstanding version queries.
    query_version_requests: HashMap<String, Sender<WarningResult<VersionInfo, W>>>,
    /// Outstanding cache-clear requests.
    clear_cache_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding terminate requests.
    terminate_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding terminate-all requests.
    terminate_all_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding model queries.
    query_models_requests: HashMap<String, Sender<WarningResult<Vec<Model>, W>>>,
}
impl<W: WarningHandling> PendingRequests<W> {
async fn poison_all(&mut self, error: Error) {
for (_, request) in self.requests.drain() {
for sender in request.positions.values() {
sender.send_err(error.clone()).await;
}
}
for (_, sender) in self.query_version_requests.drain() {
sender.send_err(error.clone()).await;
}
for (_, sender) in self.clear_cache_requests.drain() {
sender.send_err(error.clone()).await;
}
for (_, sender) in self.terminate_requests.drain() {
sender.send_err(error.clone()).await;
}
for (_, sender) in self.terminate_all_requests.drain() {
sender.send_err(error.clone()).await;
}
for (_, sender) in self.query_models_requests.drain() {
sender.send_err(error.clone()).await;
}
}
async fn poison(&mut self, id: &str, error: Error) {
if let Some(request) = self.requests.remove(id) {
for sender in request.positions.values() {
sender.send_err(error.clone()).await;
}
}
if let Some(sender) = self.query_version_requests.remove(id) {
sender.send_err(error.clone()).await;
}
if let Some(sender) = self.clear_cache_requests.remove(id) {
sender.send_err(error.clone()).await;
}
if let Some(sender) = self.terminate_requests.remove(id) {
sender.send_err(error.clone()).await;
}
if let Some(sender) = self.terminate_all_requests.remove(id) {
sender.send_err(error.clone()).await;
}
if let Some(sender) = self.query_models_requests.remove(id) {
sender.send_err(error.clone()).await;
}
}
async fn add_warning(&mut self, id: &str, warning: Warning) {
if let Some(request) = self.requests.get(id) {
for sender in request.positions.values() {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
}
if let Some(sender) = self.query_version_requests.get(id) {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
if let Some(sender) = self.clear_cache_requests.get(id) {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
if let Some(sender) = self.terminate_requests.get(id) {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
if let Some(sender) = self.terminate_all_requests.get(id) {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
if let Some(sender) = self.query_models_requests.get(id) {
sender
.send_modify(|r| W::add_warning(r, warning.clone()))
.await;
}
}
}
impl<W: WarningHandling> std::fmt::Debug for PendingRequests<W>
where
    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
    W::OkType<VersionInfo>: std::fmt::Debug,
    W::OkType<()>: std::fmt::Debug,
    W::OkType<Vec<Model>>: std::fmt::Debug,
{
    /// Formats every registry map in declaration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            requests,
            query_version_requests,
            clear_cache_requests,
            terminate_requests,
            terminate_all_requests,
            query_models_requests,
        } = self;
        let mut builder = f.debug_struct("PendingRequests");
        builder.field("requests", requests);
        builder.field("query_version_requests", query_version_requests);
        builder.field("clear_cache_requests", clear_cache_requests);
        builder.field("terminate_requests", terminate_requests);
        builder.field("terminate_all_requests", terminate_all_requests);
        builder.field("query_models_requests", query_models_requests);
        builder.finish()
    }
}
/// Bookkeeping for one in-flight analysis request.
struct PendingRequest<W: WarningHandling = WarningsAsErrors> {
    /// One sender per requested turn. A turn's entry is removed once no more updates for it
    /// are expected.
    positions: HashMap<usize, Sender<WarningResult<Option<AnalysisResult>, W>>>,
    /// Board width from the request; passed to `AnalysisResult::from_engine_response`.
    width: u8,
    /// Board height from the request; passed to `AnalysisResult::from_engine_response`.
    height: u8,
}
impl<W: WarningHandling> std::fmt::Debug for PendingRequest<W>
where
    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
{
    /// Formats all fields. `width` was previously left out, which made the output
    /// misleading for non-square boards.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PendingRequest")
            .field("positions", &self.positions)
            .field("width", &self.width)
            .field("height", &self.height)
            .finish()
    }
}
/// Owned handle to a shared `Notify` that calls `notify_one` when it is dropped.
///
/// `Sender` holds one, so dropping a sender signals the `Receiver` created alongside it by
/// `channel`.
#[derive(Debug)]
struct NotifyOnDrop(Arc<Notify>);
impl Drop for NotifyOnDrop {
fn drop(&mut self) {
self.0.notify_one();
}
}
impl Deref for NotifyOnDrop {
type Target = Arc<Notify>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
/// Writing half of a progress channel created by `channel`.
///
/// Field order is load-bearing. Rust drops fields in declaration order, so `value`'s `Arc`
/// is released before `notify` fires on drop. A receiver woken by that final notification
/// can therefore take sole ownership of the value with `Arc::try_unwrap`.
#[derive(Debug)]
struct Sender<T> {
    /// Shared slot holding the current value; the paired `Receiver` holds the other `Arc`.
    value: Arc<RwLock<T>>,
    /// Signals the receiver after each update, and once more when this sender is dropped.
    notify: NotifyOnDrop,
}
impl<T> Sender<T> {
    /// Applies `f` to the shared value under its write lock, then signals the receiver.
    ///
    /// The write guard is released before the receiver is signalled.
    async fn send_modify(&self, f: impl FnOnce(&mut T)) {
        {
            let mut slot = self.value.write().await;
            f(&mut *slot);
        }
        self.notify.notify_one();
    }
}
impl<T, E> Sender<std::result::Result<T, E>> {
    /// Overwrites the shared value with `Err(value)` and signals the receiver.
    async fn send_err(&self, value: E) {
        let failure = Err(value);
        self.send_modify(move |slot| *slot = failure).await;
    }
}
/// Reading half of a progress channel created by `channel`.
///
/// The receiver can read intermediate values at any time. It obtains the final value by
/// unwrapping the shared `Arc` once the `Sender` has been dropped.
#[derive(Debug)]
struct Receiver<T> {
    /// Shared slot holding the current value; the paired `Sender` holds the other `Arc`.
    value: Arc<RwLock<T>>,
    /// Signalled by the sender on every update and when the sender is dropped.
    notify: Arc<Notify>,
}
impl<T> Receiver<T> {
async fn finish(mut self) -> T {
loop {
match self.poll().await {
ControlFlow::Break(value) => return value,
ControlFlow::Continue(s) => self = s,
};
}
}
async fn poll(self) -> ControlFlow<T, Self> {
self.notify.notified().await;
match Arc::try_unwrap(self.value) {
Ok(value) => ControlFlow::Break(value.into_inner()),
Err(arc) => ControlFlow::Continue(Self { value: arc, ..self }),
}
}
async fn read(&self) -> RwLockReadGuard<'_, T> {
self.value.read().await
}
}
/// Creates a linked `Sender`/`Receiver` pair whose shared slot starts out holding `value`.
///
/// The sender updates the slot and signals the receiver. The receiver gets the final value
/// once the sender has been dropped.
fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
    let slot = Arc::new(RwLock::new(value));
    let signal = Arc::new(Notify::new());
    let sender = Sender {
        value: Arc::clone(&slot),
        notify: NotifyOnDrop(Arc::clone(&signal)),
    };
    (
        sender,
        Receiver {
            value: slot,
            notify: signal,
        },
    )
}
/// Handle to an in-flight analysis covering one or more turns of a single request.
///
/// Returned by the `Analyzer::start_analyze_*` methods. Await `finish` to collect every
/// turn's result, or inspect the per-turn handles individually.
pub struct GameAnalysisProgress<W: WarningHandling = WarningsAsErrors> {
    /// Engine request id shared by every turn in `positions`.
    id: String,
    /// Per-turn progress handles, keyed by turn number.
    positions: HashMap<usize, AnalysisProgress<W>>,
}
impl<W: WarningHandling> GameAnalysisProgress<W> {
    /// Waits for every turn and merges the outcomes into one map keyed by turn number.
    ///
    /// Turns are awaited one after another, and each outcome is folded in with `W::merge`.
    /// Turns whose analysis is `None` are left out of the map.
    pub async fn finish(self) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
        let mut merged = W::ok(HashMap::new());
        for (turn, progress) in self.positions {
            let outcome = progress.finish().await;
            merged = W::merge(merged, outcome, |mut by_turn, maybe_analysis| {
                if let Some(analysis) = maybe_analysis {
                    by_turn.insert(turn, analysis);
                }
                by_turn
            });
        }
        merged
    }

    /// Borrows the per-turn progress handles, keyed by turn number.
    pub fn positions(&self) -> &HashMap<usize, AnalysisProgress<W>> {
        let Self { positions, .. } = self;
        positions
    }

    /// Mutably borrows the per-turn progress handles, keyed by turn number.
    pub fn positions_mut(&mut self) -> &mut HashMap<usize, AnalysisProgress<W>> {
        let Self { positions, .. } = self;
        positions
    }

    /// Consumes this handle and returns the per-turn progress handles.
    pub fn into_positions(self) -> HashMap<usize, AnalysisProgress<W>> {
        let Self { positions, .. } = self;
        positions
    }
}
impl<W: WarningHandling> std::fmt::Debug for GameAnalysisProgress<W>
where
    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
{
    /// Formats the request id and the per-turn handles.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { id, positions } = self;
        let mut builder = f.debug_struct("GameAnalysisProgress");
        builder.field("id", id);
        builder.field("positions", positions);
        builder.finish()
    }
}
/// Handle to the analysis of a single turn within an in-flight request.
pub struct AnalysisProgress<W: WarningHandling = WarningsAsErrors> {
    /// Receives intermediate and final results for this turn.
    receiver: Receiver<WarningResult<Option<AnalysisResult>, W>>,
    /// Engine request id this turn belongs to.
    id: String,
    /// Turn number being analyzed.
    turn_number: usize,
}
impl<W: WarningHandling> AnalysisProgress<W> {
    /// Waits until no further updates for this turn can arrive, then returns the final value.
    pub async fn finish(self) -> WarningResult<Option<AnalysisResult>, W> {
        let Self { receiver, .. } = self;
        receiver.finish().await
    }

    /// Waits for the next signal for this turn.
    ///
    /// Returns `Break` with the final value once this turn's sender has been dropped.
    /// Otherwise returns `Continue` with this handle so it can be read or polled again.
    pub async fn poll(self) -> ControlFlow<WarningResult<Option<AnalysisResult>, W>, Self> {
        let Self {
            receiver,
            id,
            turn_number,
        } = self;
        match receiver.poll().await {
            ControlFlow::Break(done) => ControlFlow::Break(done),
            ControlFlow::Continue(receiver) => ControlFlow::Continue(Self {
                receiver,
                id,
                turn_number,
            }),
        }
    }

    /// Read-locks the latest value received for this turn.
    pub async fn read(&self) -> RwLockReadGuard<'_, WarningResult<Option<AnalysisResult>, W>> {
        let Self { receiver, .. } = self;
        receiver.read().await
    }
}
impl<W: WarningHandling> std::fmt::Debug for AnalysisProgress<W>
where
    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
{
    /// Formats the receiver, request id and turn number.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            receiver,
            id,
            turn_number,
        } = self;
        let mut builder = f.debug_struct("AnalysisProgress");
        builder.field("receiver", receiver);
        builder.field("id", id);
        builder.field("turn_number", turn_number);
        builder.finish()
    }
}