1use std::{
2 collections::HashMap,
3 ops::{ControlFlow, Deref},
4 sync::Arc,
5};
6
7use tokio::{
8 process::{Child, ChildStderr},
9 sync::{Notify, RwLock, RwLockReadGuard},
10};
11use tokio_stream::StreamExt;
12
13use crate::{
14 engine::{Engine, EngineStdin, EngineStdout},
15 *,
16};
17
/// Client handle for an analysis engine running as a child process.
///
/// Requests are written to the engine's stdin, and responses are matched back
/// to their requests through the shared `pending_requests` table.
///
/// `W` selects the [`WarningHandling`] strategy and defaults to
/// [`WarningsAsErrors`].
pub struct Analyzer<W: WarningHandling = WarningsAsErrors> {
    /// Write half used to send requests to the engine.
    stdin: EngineStdin,

    /// The child's stderr stream, if the [`Engine`] supplied one.
    pub stderr: Option<ChildStderr>,

    /// Handle to the engine's child process.
    pub child_process: Child,

    /// Counter from which the next request id is derived (see `generate_id`).
    next_id: u32,
    /// Requests that have been sent and are still awaiting responses.
    pending_requests: Arc<RwLock<PendingRequests<W>>>,
}
34
35impl<W: WarningHandling> Analyzer<W> {
36 pub async fn analyze(
38 &mut self,
39 request: AnalysisRequest,
40 ) -> WarningResult<Option<AnalysisResult>, W> {
41 self.start_analyze(request).await?.finish().await
42 }
43
44 pub async fn analyze_position(
46 &mut self,
47 request: AnalysisRequest,
48 position: usize,
49 ) -> WarningResult<Option<AnalysisResult>, W> {
50 self.start_analyze_position(request, position)
51 .await?
52 .finish()
53 .await
54 }
55
56 pub async fn analyze_game(
58 &mut self,
59 request: AnalysisRequest,
60 ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
61 self.start_analyze_game(request).await?.finish().await
62 }
63
64 pub async fn analyze_positions(
66 &mut self,
67 request: AnalysisRequest,
68 analyze_turns: Vec<usize>,
69 ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
70 self.start_analyze_positions(request, analyze_turns)
71 .await?
72 .finish()
73 .await
74 }
75
76 pub async fn start_analyze(&mut self, request: AnalysisRequest) -> Result<AnalysisProgress<W>> {
78 let position = request.moves.len();
79 self.start_analyze_position(request, position).await
80 }
81
82 pub async fn start_analyze_position(
84 &mut self,
85 request: AnalysisRequest,
86 position: usize,
87 ) -> Result<AnalysisProgress<W>> {
88 Ok(self
89 .start_analyze_positions(request, vec![position])
90 .await?
91 .into_positions()
92 .remove(&position)
93 .expect("position analysis should be available"))
94 }
95
96 pub async fn start_analyze_game(
98 &mut self,
99 request: AnalysisRequest,
100 ) -> Result<GameAnalysisProgress<W>> {
101 let positions = (0..=request.moves.len()).collect();
102 self.start_analyze_positions(request, positions).await
103 }
104
105 pub async fn start_analyze_positions(
107 &mut self,
108 request: AnalysisRequest,
109 analyze_turns: Vec<usize>,
110 ) -> Result<GameAnalysisProgress<W>> {
111 self.start_analyze_positions_impl(request, analyze_turns, None)
112 .await
113 }
114
115 pub async fn analyze_game_prioritized(
120 &mut self,
121 request: AnalysisRequest,
122 priorities: Vec<i32>,
123 ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
124 self.start_analyze_game_prioritized(request, priorities)
125 .await?
126 .finish()
127 .await
128 }
129
130 pub async fn analyze_positions_prioritized(
135 &mut self,
136 request: AnalysisRequest,
137 analyze_turns: Vec<usize>,
138 priorities: Vec<i32>,
139 ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
140 self.start_analyze_positions_prioritized(request, analyze_turns, priorities)
141 .await?
142 .finish()
143 .await
144 }
145
146 pub async fn start_analyze_game_prioritized(
150 &mut self,
151 request: AnalysisRequest,
152 priorities: Vec<i32>,
153 ) -> Result<GameAnalysisProgress<W>> {
154 let positions = (0..=request.moves.len()).collect();
155 self.start_analyze_positions_prioritized(request, positions, priorities)
156 .await
157 }
158
159 pub async fn start_analyze_positions_prioritized(
164 &mut self,
165 request: AnalysisRequest,
166 analyze_turns: Vec<usize>,
167 priorities: Vec<i32>,
168 ) -> Result<GameAnalysisProgress<W>> {
169 self.start_analyze_positions_impl(request, analyze_turns, Some(priorities))
170 .await
171 }
172
173 async fn start_analyze_positions_impl(
174 &mut self,
175 request: AnalysisRequest,
176 analyze_turns: Vec<usize>,
177 priorities: Option<Vec<i32>>,
178 ) -> Result<GameAnalysisProgress<W>> {
179 let id = self.generate_id();
180 let mut senders = HashMap::new();
181 let mut positions = HashMap::new();
182 for position in &analyze_turns {
183 let (sender, receiver) = channel(W::ok(None));
184 senders.insert(*position, sender);
185 positions.insert(
186 *position,
187 AnalysisProgress::<W> {
188 receiver,
189 id: id.clone(),
190 turn_number: *position,
191 },
192 );
193 }
194
195 let pending_request = PendingRequest::<W> {
196 positions: senders,
197 width: request.board_x_size,
198 height: request.board_y_size,
199 };
200
201 let mut pending = self.pending_requests.write().await;
202 self.stdin
203 .send(&engine::Request::Analyze(request.into_engine_request(
204 id.clone(),
205 analyze_turns,
206 priorities,
207 )))
208 .await?;
209 pending.requests.insert(id.clone(), pending_request);
210 Ok(GameAnalysisProgress::<W> { id, positions })
211 }
212
213 pub async fn query_version(&mut self) -> WarningResult<VersionInfo, W> {
215 let id = self.generate_id();
216 let (sender, receiver) = channel(W::ok(VersionInfo {
217 version: String::new(),
218 git_hash: String::new(),
219 }));
220
221 let mut pending = self.pending_requests.write().await;
222 self.stdin
223 .send(&engine::Request::QueryVersion { id: id.clone() })
224 .await?;
225 pending.query_version_requests.insert(id, sender);
226 drop(pending);
227
228 receiver.finish().await
229 }
230
231 pub async fn clear_cache(&mut self) -> WarningResult<(), W> {
233 let id = self.generate_id();
234 let (sender, receiver) = channel(W::ok(()));
235
236 let mut pending = self.pending_requests.write().await;
237 self.stdin
238 .send(&engine::Request::ClearCache { id: id.clone() })
239 .await?;
240 pending.clear_cache_requests.insert(id, sender);
241 drop(pending);
242
243 receiver.finish().await
244 }
245
246 pub async fn terminate(&mut self, progress: &AnalysisProgress) -> WarningResult<(), W> {
250 self.terminate_impl(progress.id.clone(), Some(vec![progress.turn_number]))
251 .await
252 }
253
254 pub async fn terminate_game(
258 &mut self,
259 progress: &GameAnalysisProgress,
260 ) -> WarningResult<(), W> {
261 self.terminate_impl(progress.id.clone(), None).await
262 }
263
264 pub async fn terminate_positions(
268 &mut self,
269 progress: &GameAnalysisProgress,
270 turn_numbers: Vec<usize>,
271 ) -> WarningResult<(), W> {
272 self.terminate_impl(progress.id.clone(), Some(turn_numbers))
273 .await
274 }
275
276 async fn terminate_impl(
277 &mut self,
278 terminate_id: String,
279 turn_numbers: Option<Vec<usize>>,
280 ) -> WarningResult<(), W> {
281 let id = self.generate_id();
282 let (sender, receiver) = channel(W::ok(()));
283
284 let mut pending = self.pending_requests.write().await;
285 self.stdin
286 .send(&engine::Request::Terminate {
287 id: id.clone(),
288 terminate_id,
289 turn_numbers,
290 })
291 .await?;
292 pending.terminate_requests.insert(id, sender);
293 drop(pending);
294
295 receiver.finish().await
296 }
297
298 pub async fn terminate_all(&mut self) -> WarningResult<(), W> {
300 self.terminate_all_impl(None).await
301 }
302
303 pub async fn terminate_all_positions(
305 &mut self,
306 turn_numbers: Vec<usize>,
307 ) -> WarningResult<(), W> {
308 self.terminate_all_impl(Some(turn_numbers)).await
309 }
310
311 async fn terminate_all_impl(
312 &mut self,
313 turn_numbers: Option<Vec<usize>>,
314 ) -> WarningResult<(), W> {
315 let id = self.generate_id();
316 let (sender, receiver) = channel(W::ok(()));
317
318 let mut pending = self.pending_requests.write().await;
319 self.stdin
320 .send(&engine::Request::TerminateAll {
321 id: id.clone(),
322 turn_numbers,
323 })
324 .await?;
325 pending.terminate_all_requests.insert(id, sender);
326 drop(pending);
327
328 receiver.finish().await
329 }
330
331 pub async fn query_models(&mut self) -> WarningResult<Vec<Model>, W> {
333 let id = self.generate_id();
334 let (sender, receiver) = channel(W::ok(vec![]));
335
336 let mut pending = self.pending_requests.write().await;
337 self.stdin
338 .send(&engine::Request::QueryModels { id: id.clone() })
339 .await?;
340 pending.query_models_requests.insert(id, sender);
341 drop(pending);
342
343 receiver.finish().await
344 }
345
346 fn generate_id(&mut self) -> String {
347 let id = self.next_id.to_string();
348 self.next_id += 1;
349 id
350 }
351}
352
353impl<W: WarningHandling + Default + Clone + 'static> From<Engine> for Analyzer<W>
354where
355 W::OkType<Option<AnalysisResult>>: Send + Sync,
356 W::OkType<VersionInfo>: Send + Sync,
357 W::OkType<()>: Send + Sync,
358 W::OkType<Vec<Model>>: Send + Sync,
359{
360 fn from(engine: Engine) -> Self {
361 let client = Self {
362 stdin: engine.stdin,
363 stderr: engine.stderr,
364 child_process: engine.child_process,
365 next_id: 1,
366 pending_requests: Arc::default(),
367 };
368
369 tokio::spawn(handle_responses(
370 engine.stdout,
371 client.pending_requests.clone(),
372 ));
373
374 client
375 }
376}
377
/// Response loop: reads engine responses from `stdout` until the stream ends
/// and routes each one to the matching entry in `pending`.
///
/// A sender is removed from the table once its request is complete. Dropping
/// a sender is what lets the corresponding receiver finish, so the order of
/// "send, then remove/drop" below is significant.
///
/// When the stream ends, every remaining sender is dropped, which releases
/// all receivers that are still waiting.
async fn handle_responses<W: WarningHandling>(
    mut stdout: EngineStdout,
    pending: Arc<RwLock<PendingRequests<W>>>,
) {
    while let Some(response) = stdout.next().await {
        let response = match response {
            Ok(response) => response,
            Err(e) => {
                // A stream-level error cannot be attributed to one request,
                // so every outstanding request is failed with it.
                pending.write().await.poison_all(e).await;
                continue;
            }
        };
        match response {
            engine::Response::Analyze(response) => {
                // Copy out what is needed after `response` is consumed below.
                let id = response.id.clone();
                let turn_number = response.turn_number;
                let is_during_search = response.is_during_search;
                let mut pending = pending.write().await;
                if let Some(request) = pending.requests.get_mut(&id) {
                    if let Some(sender) = request.positions.get(&turn_number) {
                        let result = Some(AnalysisResult::from_engine_response(
                            response,
                            request.width,
                            request.height,
                        ));
                        sender.send_modify(|r| W::set_result(r, result)).await;

                        // Responses flagged as "during search" are
                        // intermediate; only a response without that flag
                        // completes the turn and drops its sender.
                        if !is_during_search {
                            request.positions.remove(&turn_number);
                        }
                    }
                    // Forget the whole request once no turn is outstanding.
                    if request.positions.is_empty() {
                        pending.requests.remove(&id);
                    }
                }
            }
            engine::Response::NoResults { id, turn_number } => {
                // The turn completes without a result: dropping the sender
                // leaves the receiver with its initial value.
                let mut pending = pending.write().await;
                if let Some(request) = pending.requests.get_mut(&id) {
                    request.positions.remove(&turn_number);
                    if request.positions.is_empty() {
                        pending.requests.remove(&id);
                    }
                }
            }
            engine::Response::QueryVersion {
                id,
                version,
                git_hash,
            } => {
                // One-shot replies: the sender is removed from the table,
                // used once, and dropped at the end of the `if let`.
                if let Some(sender) = pending.write().await.query_version_requests.remove(&id) {
                    sender
                        .send_modify(|r| W::set_result(r, VersionInfo { version, git_hash }))
                        .await;
                }
            }
            engine::Response::ClearCache { id } => {
                if let Some(sender) = pending.write().await.clear_cache_requests.remove(&id) {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::Terminate { id, .. } => {
                if let Some(sender) = pending.write().await.terminate_requests.remove(&id) {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::TerminateAll { id, .. } => {
                if let Some(sender) = pending.write().await.terminate_all_requests.remove(&id) {
                    sender.send_modify(|r| W::set_result(r, ())).await;
                }
            }
            engine::Response::QueryModels { id, models } => {
                if let Some(sender) = pending.write().await.query_models_requests.remove(&id) {
                    sender.send_modify(|r| W::set_result(r, models)).await;
                }
            }
            engine::Response::GeneralError { error } => {
                // A general error carries no request id, so it fails every
                // outstanding request.
                pending
                    .write()
                    .await
                    .poison_all(Error::KataGoGeneralError { error })
                    .await;
            }
            engine::Response::FieldError { id, error, field } => {
                // A field error is tied to one request id and fails only it.
                pending
                    .write()
                    .await
                    .poison(&id, Error::KataGoFieldError { error, field })
                    .await;
            }
            engine::Response::FieldWarning { id, warning, field } => {
                // Warnings are attached to the request but do not remove it;
                // the request still completes through its normal response.
                pending
                    .write()
                    .await
                    .add_warning(&id, Warning { warning, field })
                    .await;
            }
        };
    }
    // The stream ended: drop every remaining sender so that receivers which
    // are still waiting are released instead of hanging forever.
    let mut pending = pending.write().await;
    pending.requests.clear();
    pending.query_version_requests.clear();
    pending.clear_cache_requests.clear();
    pending.terminate_requests.clear();
    pending.terminate_all_requests.clear();
    pending.query_models_requests.clear();
}
485
486impl<W: WarningHandling> std::fmt::Debug for Analyzer<W>
487where
488 W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
489 W::OkType<VersionInfo>: std::fmt::Debug,
490 W::OkType<()>: std::fmt::Debug,
491 W::OkType<Vec<Model>>: std::fmt::Debug,
492{
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 f.debug_struct("Analyzer")
495 .field("stdin", &self.stdin)
496 .field("stderr", &self.stderr)
497 .field("child_process", &self.child_process)
498 .field("next_id", &self.next_id)
499 .field("pending_requests", &self.pending_requests)
500 .finish()
501 }
502}
503
/// Table of requests that have been sent to the engine and are awaiting a
/// response, grouped by request kind and keyed by request id.
#[derive(Default)]
struct PendingRequests<W: WarningHandling = WarningsAsErrors> {
    /// Analysis requests; each holds one sender per turn still outstanding.
    requests: HashMap<String, PendingRequest<W>>,
    /// Outstanding version queries.
    query_version_requests: HashMap<String, Sender<WarningResult<VersionInfo, W>>>,
    /// Outstanding clear-cache requests.
    clear_cache_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding terminate requests.
    terminate_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding terminate-all requests.
    terminate_all_requests: HashMap<String, Sender<WarningResult<(), W>>>,
    /// Outstanding model-list queries.
    query_models_requests: HashMap<String, Sender<WarningResult<Vec<Model>, W>>>,
}
513
514impl<W: WarningHandling> PendingRequests<W> {
515 async fn poison_all(&mut self, error: Error) {
516 for (_, request) in self.requests.drain() {
517 for sender in request.positions.values() {
518 sender.send_err(error.clone()).await;
519 }
520 }
521
522 for (_, sender) in self.query_version_requests.drain() {
523 sender.send_err(error.clone()).await;
524 }
525
526 for (_, sender) in self.clear_cache_requests.drain() {
527 sender.send_err(error.clone()).await;
528 }
529
530 for (_, sender) in self.terminate_requests.drain() {
531 sender.send_err(error.clone()).await;
532 }
533
534 for (_, sender) in self.terminate_all_requests.drain() {
535 sender.send_err(error.clone()).await;
536 }
537
538 for (_, sender) in self.query_models_requests.drain() {
539 sender.send_err(error.clone()).await;
540 }
541 }
542
543 async fn poison(&mut self, id: &str, error: Error) {
544 if let Some(request) = self.requests.remove(id) {
545 for sender in request.positions.values() {
546 sender.send_err(error.clone()).await;
547 }
548 }
549
550 if let Some(sender) = self.query_version_requests.remove(id) {
551 sender.send_err(error.clone()).await;
552 }
553
554 if let Some(sender) = self.clear_cache_requests.remove(id) {
555 sender.send_err(error.clone()).await;
556 }
557
558 if let Some(sender) = self.terminate_requests.remove(id) {
559 sender.send_err(error.clone()).await;
560 }
561
562 if let Some(sender) = self.terminate_all_requests.remove(id) {
563 sender.send_err(error.clone()).await;
564 }
565
566 if let Some(sender) = self.query_models_requests.remove(id) {
567 sender.send_err(error.clone()).await;
568 }
569 }
570
571 async fn add_warning(&mut self, id: &str, warning: Warning) {
572 if let Some(request) = self.requests.get(id) {
573 for sender in request.positions.values() {
574 sender
575 .send_modify(|r| W::add_warning(r, warning.clone()))
576 .await;
577 }
578 }
579
580 if let Some(sender) = self.query_version_requests.get(id) {
581 sender
582 .send_modify(|r| W::add_warning(r, warning.clone()))
583 .await;
584 }
585
586 if let Some(sender) = self.clear_cache_requests.get(id) {
587 sender
588 .send_modify(|r| W::add_warning(r, warning.clone()))
589 .await;
590 }
591
592 if let Some(sender) = self.terminate_requests.get(id) {
593 sender
594 .send_modify(|r| W::add_warning(r, warning.clone()))
595 .await;
596 }
597
598 if let Some(sender) = self.terminate_all_requests.get(id) {
599 sender
600 .send_modify(|r| W::add_warning(r, warning.clone()))
601 .await;
602 }
603
604 if let Some(sender) = self.query_models_requests.get(id) {
605 sender
606 .send_modify(|r| W::add_warning(r, warning.clone()))
607 .await;
608 }
609 }
610}
611
612impl<W: WarningHandling> std::fmt::Debug for PendingRequests<W>
613where
614 W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
615 W::OkType<VersionInfo>: std::fmt::Debug,
616 W::OkType<()>: std::fmt::Debug,
617 W::OkType<Vec<Model>>: std::fmt::Debug,
618{
619 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620 f.debug_struct("PendingRequests")
621 .field("requests", &self.requests)
622 .field("query_version_requests", &self.query_version_requests)
623 .field("clear_cache_requests", &self.clear_cache_requests)
624 .field("terminate_requests", &self.terminate_requests)
625 .field("terminate_all_requests", &self.terminate_all_requests)
626 .field("query_models_requests", &self.query_models_requests)
627 .finish()
628 }
629}
630
/// Bookkeeping for one in-flight analysis request.
struct PendingRequest<W: WarningHandling = WarningsAsErrors> {
    /// One sender per turn that has not completed yet, keyed by turn number.
    positions: HashMap<usize, Sender<WarningResult<Option<AnalysisResult>, W>>>,
    /// Board width from the request; passed along when converting engine
    /// responses into `AnalysisResult`s.
    width: u8,
    /// Board height from the request; passed along when converting engine
    /// responses into `AnalysisResult`s.
    height: u8,
}
636
637impl<W: WarningHandling> std::fmt::Debug for PendingRequest<W>
638where
639 W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
640{
641 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
642 f.debug_struct("PendingRequest")
643 .field("positions", &self.positions)
644 .field("height", &self.height)
645 .finish()
646 }
647}
648
/// Wrapper around a shared [`Notify`] that sends one last notification when
/// it is dropped.
#[derive(Debug)]
struct NotifyOnDrop(Arc<Notify>);
651
652impl Drop for NotifyOnDrop {
653 fn drop(&mut self) {
654 self.0.notify_one();
655 }
656}
657
658impl Deref for NotifyOnDrop {
659 type Target = Arc<Notify>;
660
661 fn deref(&self) -> &Self::Target {
662 &self.0
663 }
664}
665
/// Sending half of the channel created by `channel`.
///
/// The sender updates the shared value in place and notifies the receiver.
/// Dropping the sender tells the receiver that no more updates will follow.
#[derive(Debug)]
struct Sender<T> {
    // Field order matters: fields drop in declaration order, so this `Arc`
    // clone is released *before* `notify` fires its drop notification. When
    // the receiver wakes up, `Arc::try_unwrap` in `Receiver::poll` can then
    // succeed.
    value: Arc<RwLock<T>>,
    /// Notifier shared with the receiver; also notifies when dropped.
    notify: NotifyOnDrop,
}
674
675impl<T> Sender<T> {
676 async fn send_modify(&self, f: impl FnOnce(&mut T)) {
677 f(&mut *self.value.write().await);
678 self.notify.notify_one();
679 }
680}
681
682impl<T, E> Sender<std::result::Result<T, E>> {
683 async fn send_err(&self, value: E) {
684 self.send_modify(|r| *r = Err(value)).await;
685 }
686}
687
/// Receiving half of the channel created by `channel`.
#[derive(Debug)]
struct Receiver<T> {
    /// Value shared with the sender; it can be taken out by value once the
    /// sender's clone of this `Arc` has been dropped.
    value: Arc<RwLock<T>>,
    /// Notified by the sender after each update and when the sender drops.
    notify: Arc<Notify>,
}
694
695impl<T> Receiver<T> {
696 async fn finish(mut self) -> T {
697 loop {
698 match self.poll().await {
699 ControlFlow::Break(value) => return value,
700 ControlFlow::Continue(s) => self = s,
701 };
702 }
703 }
704
705 async fn poll(self) -> ControlFlow<T, Self> {
706 self.notify.notified().await;
707 match Arc::try_unwrap(self.value) {
708 Ok(value) => ControlFlow::Break(value.into_inner()),
709 Err(arc) => ControlFlow::Continue(Self { value: arc, ..self }),
710 }
711 }
712
713 async fn read(&self) -> RwLockReadGuard<'_, T> {
714 self.value.read().await
715 }
716}
717
718fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
720 let receiver = Receiver {
721 value: Arc::new(RwLock::new(value)),
722 notify: Arc::new(Notify::new()),
723 };
724 let sender = Sender {
725 value: receiver.value.clone(),
726 notify: NotifyOnDrop(receiver.notify.clone()),
727 };
728 (sender, receiver)
729}
730
/// Handle to an analysis request that covers one or more turns.
pub struct GameAnalysisProgress<W: WarningHandling = WarningsAsErrors> {
    /// Id of the request sent to the engine.
    id: String,
    /// Per-turn progress handles, keyed by turn number.
    positions: HashMap<usize, AnalysisProgress<W>>,
}
736
737impl<W: WarningHandling> GameAnalysisProgress<W> {
738 pub async fn finish(self) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
742 let mut results = W::ok(HashMap::new());
743 for (position, progress) in self.into_positions().into_iter() {
744 let result = progress.finish().await;
745 results = W::merge(results, result, |mut results, result| {
746 if let Some(result) = result {
747 results.insert(position, result);
748 }
749 results
750 });
751 }
752 results
753 }
754
755 pub fn positions(&self) -> &HashMap<usize, AnalysisProgress<W>> {
757 &self.positions
758 }
759
760 pub fn positions_mut(&mut self) -> &mut HashMap<usize, AnalysisProgress<W>> {
762 &mut self.positions
763 }
764
765 pub fn into_positions(self) -> HashMap<usize, AnalysisProgress<W>> {
767 self.positions
768 }
769}
770
771impl<W: WarningHandling> std::fmt::Debug for GameAnalysisProgress<W>
772where
773 W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
774{
775 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
776 f.debug_struct("GameAnalysisProgress")
777 .field("id", &self.id)
778 .field("positions", &self.positions)
779 .finish()
780 }
781}
782
/// Handle to the analysis of a single turn.
pub struct AnalysisProgress<W: WarningHandling = WarningsAsErrors> {
    /// Receives result updates for this turn from the response task.
    receiver: Receiver<WarningResult<Option<AnalysisResult>, W>>,
    /// Id of the request this turn belongs to.
    id: String,
    /// Turn number being analyzed.
    turn_number: usize,
}
789
790impl<W: WarningHandling> AnalysisProgress<W> {
791 pub async fn finish(self) -> WarningResult<Option<AnalysisResult>, W> {
795 self.receiver.finish().await
796 }
797
798 pub async fn poll(self) -> ControlFlow<WarningResult<Option<AnalysisResult>, W>, Self> {
844 self.receiver.poll().await.map_continue(|r| Self {
845 receiver: r,
846 ..self
847 })
848 }
849
850 pub async fn read(&self) -> RwLockReadGuard<'_, WarningResult<Option<AnalysisResult>, W>> {
854 self.receiver.read().await
855 }
856}
857
858impl<W: WarningHandling> std::fmt::Debug for AnalysisProgress<W>
859where
860 W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
861{
862 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863 f.debug_struct("AnalysisProgress")
864 .field("receiver", &self.receiver)
865 .field("id", &self.id)
866 .field("turn_number", &self.turn_number)
867 .finish()
868 }
869}