Skip to main content

katago_analysis/
analyzer.rs

1use std::{
2    collections::HashMap,
3    ops::{ControlFlow, Deref},
4    sync::Arc,
5};
6
7use tokio::{
8    process::{Child, ChildStderr},
9    sync::{Notify, RwLock, RwLockReadGuard},
10};
11use tokio_stream::StreamExt;
12
13use crate::{
14    engine::{Engine, EngineStdin, EngineStdout},
15    *,
16};
17
18/// An instance of the KataGo analysis engine, launched as a child process.
19///
20/// Drop this to close the engine's stdin and request KataGo to exit.
21/// Responses will continue to be processed until the engine actually exits.
22pub struct Analyzer<W: WarningHandling = WarningsAsErrors> {
23    stdin: EngineStdin,
24
25    /// The analysis engine's stderr output, if available.
26    pub stderr: Option<ChildStderr>,
27
28    /// The engine process.
29    pub child_process: Child,
30
31    next_id: u32,
32    pending_requests: Arc<RwLock<PendingRequests<W>>>,
33}
34
35impl<W: WarningHandling> Analyzer<W> {
36    /// Analyzes the final position in the game and returns a single result.
37    pub async fn analyze(
38        &mut self,
39        request: AnalysisRequest,
40    ) -> WarningResult<Option<AnalysisResult>, W> {
41        self.start_analyze(request).await?.finish().await
42    }
43
44    /// Analyzes a specific position in the game and returns a single result.
45    pub async fn analyze_position(
46        &mut self,
47        request: AnalysisRequest,
48        position: usize,
49    ) -> WarningResult<Option<AnalysisResult>, W> {
50        self.start_analyze_position(request, position)
51            .await?
52            .finish()
53            .await
54    }
55
56    /// Analyzes all moves in the game and returns a collection of results, one for each position.
57    pub async fn analyze_game(
58        &mut self,
59        request: AnalysisRequest,
60    ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
61        self.start_analyze_game(request).await?.finish().await
62    }
63
64    /// Analyzes the specified positions in the game and returns a collection of results, one for each position.
65    pub async fn analyze_positions(
66        &mut self,
67        request: AnalysisRequest,
68        analyze_turns: Vec<usize>,
69    ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
70        self.start_analyze_positions(request, analyze_turns)
71            .await?
72            .finish()
73            .await
74    }
75
76    /// Starts analyzing the final position in the game and returns a progress object which can be polled for updates.
77    pub async fn start_analyze(&mut self, request: AnalysisRequest) -> Result<AnalysisProgress<W>> {
78        let position = request.moves.len();
79        self.start_analyze_position(request, position).await
80    }
81
82    /// Starts analyzing a specific position in the game and returns a progress object which can be polled for updates.
83    pub async fn start_analyze_position(
84        &mut self,
85        request: AnalysisRequest,
86        position: usize,
87    ) -> Result<AnalysisProgress<W>> {
88        Ok(self
89            .start_analyze_positions(request, vec![position])
90            .await?
91            .into_positions()
92            .remove(&position)
93            .expect("position analysis should be available"))
94    }
95
96    /// Starts analyzing all moves in the game and returns a collection of progress objects.
97    pub async fn start_analyze_game(
98        &mut self,
99        request: AnalysisRequest,
100    ) -> Result<GameAnalysisProgress<W>> {
101        let positions = (0..=request.moves.len()).collect();
102        self.start_analyze_positions(request, positions).await
103    }
104
105    /// Starts analyzing the specified positions in the game and returns a collection of progress objects.
106    pub async fn start_analyze_positions(
107        &mut self,
108        request: AnalysisRequest,
109        analyze_turns: Vec<usize>,
110    ) -> Result<GameAnalysisProgress<W>> {
111        self.start_analyze_positions_impl(request, analyze_turns, None)
112            .await
113    }
114
115    /// Analyzes all moves in the game with the given priorities and returns a collection of results, one for each
116    /// position.
117    ///
118    /// `priorities` must have length equal to one more than the number of moves in the game.
119    pub async fn analyze_game_prioritized(
120        &mut self,
121        request: AnalysisRequest,
122        priorities: Vec<i32>,
123    ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
124        self.start_analyze_game_prioritized(request, priorities)
125            .await?
126            .finish()
127            .await
128    }
129
130    /// Analyzes the specified positions in the game with the given priorities and returns a collection of results,
131    /// one for each position.
132    ///
133    /// `priorities` must have the same length as `analyze_turns`.
134    pub async fn analyze_positions_prioritized(
135        &mut self,
136        request: AnalysisRequest,
137        analyze_turns: Vec<usize>,
138        priorities: Vec<i32>,
139    ) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
140        self.start_analyze_positions_prioritized(request, analyze_turns, priorities)
141            .await?
142            .finish()
143            .await
144    }
145
146    /// Starts analyzing all moves in the game with the given priorities and returns a collection of progress objects.
147    ///
148    /// `priorities` must have length equal to one more than the number of moves in the game.
149    pub async fn start_analyze_game_prioritized(
150        &mut self,
151        request: AnalysisRequest,
152        priorities: Vec<i32>,
153    ) -> Result<GameAnalysisProgress<W>> {
154        let positions = (0..=request.moves.len()).collect();
155        self.start_analyze_positions_prioritized(request, positions, priorities)
156            .await
157    }
158
159    /// Starts analyzing the specified positions in the game with the given priorities and returns a collection of
160    /// progress objects.
161    ///
162    /// `priorities` must have the same length as `analyze_turns`.
163    pub async fn start_analyze_positions_prioritized(
164        &mut self,
165        request: AnalysisRequest,
166        analyze_turns: Vec<usize>,
167        priorities: Vec<i32>,
168    ) -> Result<GameAnalysisProgress<W>> {
169        self.start_analyze_positions_impl(request, analyze_turns, Some(priorities))
170            .await
171    }
172
173    async fn start_analyze_positions_impl(
174        &mut self,
175        request: AnalysisRequest,
176        analyze_turns: Vec<usize>,
177        priorities: Option<Vec<i32>>,
178    ) -> Result<GameAnalysisProgress<W>> {
179        let id = self.generate_id();
180        let mut senders = HashMap::new();
181        let mut positions = HashMap::new();
182        for position in &analyze_turns {
183            let (sender, receiver) = channel(W::ok(None));
184            senders.insert(*position, sender);
185            positions.insert(
186                *position,
187                AnalysisProgress::<W> {
188                    receiver,
189                    id: id.clone(),
190                    turn_number: *position,
191                },
192            );
193        }
194
195        let pending_request = PendingRequest::<W> {
196            positions: senders,
197            width: request.board_x_size,
198            height: request.board_y_size,
199        };
200
201        let mut pending = self.pending_requests.write().await;
202        self.stdin
203            .send(&engine::Request::Analyze(request.into_engine_request(
204                id.clone(),
205                analyze_turns,
206                priorities,
207            )))
208            .await?;
209        pending.requests.insert(id.clone(), pending_request);
210        Ok(GameAnalysisProgress::<W> { id, positions })
211    }
212
213    /// Requests KataGo's version information.
214    pub async fn query_version(&mut self) -> WarningResult<VersionInfo, W> {
215        let id = self.generate_id();
216        let (sender, receiver) = channel(W::ok(VersionInfo {
217            version: String::new(),
218            git_hash: String::new(),
219        }));
220
221        let mut pending = self.pending_requests.write().await;
222        self.stdin
223            .send(&engine::Request::QueryVersion { id: id.clone() })
224            .await?;
225        pending.query_version_requests.insert(id, sender);
226        drop(pending);
227
228        receiver.finish().await
229    }
230
231    /// Clears the neural network cache.
232    pub async fn clear_cache(&mut self) -> WarningResult<(), W> {
233        let id = self.generate_id();
234        let (sender, receiver) = channel(W::ok(()));
235
236        let mut pending = self.pending_requests.write().await;
237        self.stdin
238            .send(&engine::Request::ClearCache { id: id.clone() })
239            .await?;
240        pending.clear_cache_requests.insert(id, sender);
241        drop(pending);
242
243        receiver.finish().await
244    }
245
246    /// Terminates the analysis for a single position.
247    ///
248    /// `progress` may still be used to wait for the final result.
249    pub async fn terminate(&mut self, progress: &AnalysisProgress) -> WarningResult<(), W> {
250        self.terminate_impl(progress.id.clone(), Some(vec![progress.turn_number]))
251            .await
252    }
253
254    /// Terminates the analysis for all positions in a game.
255    ///
256    /// `progress` may still be used to wait for the final results.
257    pub async fn terminate_game(
258        &mut self,
259        progress: &GameAnalysisProgress,
260    ) -> WarningResult<(), W> {
261        self.terminate_impl(progress.id.clone(), None).await
262    }
263
264    /// Terminates the analysis for the specified positions in a game.
265    ///
266    /// `progress` may still be used to wait for the final results.
267    pub async fn terminate_positions(
268        &mut self,
269        progress: &GameAnalysisProgress,
270        turn_numbers: Vec<usize>,
271    ) -> WarningResult<(), W> {
272        self.terminate_impl(progress.id.clone(), Some(turn_numbers))
273            .await
274    }
275
276    async fn terminate_impl(
277        &mut self,
278        terminate_id: String,
279        turn_numbers: Option<Vec<usize>>,
280    ) -> WarningResult<(), W> {
281        let id = self.generate_id();
282        let (sender, receiver) = channel(W::ok(()));
283
284        let mut pending = self.pending_requests.write().await;
285        self.stdin
286            .send(&engine::Request::Terminate {
287                id: id.clone(),
288                terminate_id,
289                turn_numbers,
290            })
291            .await?;
292        pending.terminate_requests.insert(id, sender);
293        drop(pending);
294
295        receiver.finish().await
296    }
297
298    /// Terminates all pending analysis requests.
299    pub async fn terminate_all(&mut self) -> WarningResult<(), W> {
300        self.terminate_all_impl(None).await
301    }
302
303    /// Terminates all pending analysis requests for the specified positions.
304    pub async fn terminate_all_positions(
305        &mut self,
306        turn_numbers: Vec<usize>,
307    ) -> WarningResult<(), W> {
308        self.terminate_all_impl(Some(turn_numbers)).await
309    }
310
311    async fn terminate_all_impl(
312        &mut self,
313        turn_numbers: Option<Vec<usize>>,
314    ) -> WarningResult<(), W> {
315        let id = self.generate_id();
316        let (sender, receiver) = channel(W::ok(()));
317
318        let mut pending = self.pending_requests.write().await;
319        self.stdin
320            .send(&engine::Request::TerminateAll {
321                id: id.clone(),
322                turn_numbers,
323            })
324            .await?;
325        pending.terminate_all_requests.insert(id, sender);
326        drop(pending);
327
328        receiver.finish().await
329    }
330
331    /// Requests information about the available neural network models.
332    pub async fn query_models(&mut self) -> WarningResult<Vec<Model>, W> {
333        let id = self.generate_id();
334        let (sender, receiver) = channel(W::ok(vec![]));
335
336        let mut pending = self.pending_requests.write().await;
337        self.stdin
338            .send(&engine::Request::QueryModels { id: id.clone() })
339            .await?;
340        pending.query_models_requests.insert(id, sender);
341        drop(pending);
342
343        receiver.finish().await
344    }
345
346    fn generate_id(&mut self) -> String {
347        let id = self.next_id.to_string();
348        self.next_id += 1;
349        id
350    }
351}
352
353impl<W: WarningHandling + Default + Clone + 'static> From<Engine> for Analyzer<W>
354where
355    W::OkType<Option<AnalysisResult>>: Send + Sync,
356    W::OkType<VersionInfo>: Send + Sync,
357    W::OkType<()>: Send + Sync,
358    W::OkType<Vec<Model>>: Send + Sync,
359{
360    fn from(engine: Engine) -> Self {
361        let client = Self {
362            stdin: engine.stdin,
363            stderr: engine.stderr,
364            child_process: engine.child_process,
365            next_id: 1,
366            pending_requests: Arc::default(),
367        };
368
369        tokio::spawn(handle_responses(
370            engine.stdout,
371            client.pending_requests.clone(),
372        ));
373
374        client
375    }
376}
377
378async fn handle_responses<W: WarningHandling>(
379    mut stdout: EngineStdout,
380    pending: Arc<RwLock<PendingRequests<W>>>,
381) {
382    while let Some(response) = stdout.next().await {
383        let response = match response {
384            Ok(response) => response,
385            Err(e) => {
386                pending.write().await.poison_all(e).await;
387                continue;
388            }
389        };
390        match response {
391            engine::Response::Analyze(response) => {
392                let id = response.id.clone();
393                let turn_number = response.turn_number;
394                let is_during_search = response.is_during_search;
395                let mut pending = pending.write().await;
396                if let Some(request) = pending.requests.get_mut(&id) {
397                    if let Some(sender) = request.positions.get(&turn_number) {
398                        let result = Some(AnalysisResult::from_engine_response(
399                            response,
400                            request.width,
401                            request.height,
402                        ));
403                        sender.send_modify(|r| W::set_result(r, result)).await;
404
405                        if !is_during_search {
406                            request.positions.remove(&turn_number);
407                        }
408                    }
409                    if request.positions.is_empty() {
410                        pending.requests.remove(&id);
411                    }
412                }
413            }
414            engine::Response::NoResults { id, turn_number } => {
415                let mut pending = pending.write().await;
416                if let Some(request) = pending.requests.get_mut(&id) {
417                    request.positions.remove(&turn_number);
418                    if request.positions.is_empty() {
419                        pending.requests.remove(&id);
420                    }
421                }
422            }
423            engine::Response::QueryVersion {
424                id,
425                version,
426                git_hash,
427            } => {
428                if let Some(sender) = pending.write().await.query_version_requests.remove(&id) {
429                    sender
430                        .send_modify(|r| W::set_result(r, VersionInfo { version, git_hash }))
431                        .await;
432                }
433            }
434            engine::Response::ClearCache { id } => {
435                if let Some(sender) = pending.write().await.clear_cache_requests.remove(&id) {
436                    sender.send_modify(|r| W::set_result(r, ())).await;
437                }
438            }
439            engine::Response::Terminate { id, .. } => {
440                if let Some(sender) = pending.write().await.terminate_requests.remove(&id) {
441                    sender.send_modify(|r| W::set_result(r, ())).await;
442                }
443            }
444            engine::Response::TerminateAll { id, .. } => {
445                if let Some(sender) = pending.write().await.terminate_all_requests.remove(&id) {
446                    sender.send_modify(|r| W::set_result(r, ())).await;
447                }
448            }
449            engine::Response::QueryModels { id, models } => {
450                if let Some(sender) = pending.write().await.query_models_requests.remove(&id) {
451                    sender.send_modify(|r| W::set_result(r, models)).await;
452                }
453            }
454            engine::Response::GeneralError { error } => {
455                pending
456                    .write()
457                    .await
458                    .poison_all(Error::KataGoGeneralError { error })
459                    .await;
460            }
461            engine::Response::FieldError { id, error, field } => {
462                pending
463                    .write()
464                    .await
465                    .poison(&id, Error::KataGoFieldError { error, field })
466                    .await;
467            }
468            engine::Response::FieldWarning { id, warning, field } => {
469                pending
470                    .write()
471                    .await
472                    .add_warning(&id, Warning { warning, field })
473                    .await;
474            }
475        };
476    }
477    let mut pending = pending.write().await;
478    pending.requests.clear();
479    pending.query_version_requests.clear();
480    pending.clear_cache_requests.clear();
481    pending.terminate_requests.clear();
482    pending.terminate_all_requests.clear();
483    pending.query_models_requests.clear();
484}
485
486impl<W: WarningHandling> std::fmt::Debug for Analyzer<W>
487where
488    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
489    W::OkType<VersionInfo>: std::fmt::Debug,
490    W::OkType<()>: std::fmt::Debug,
491    W::OkType<Vec<Model>>: std::fmt::Debug,
492{
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        f.debug_struct("Analyzer")
495            .field("stdin", &self.stdin)
496            .field("stderr", &self.stderr)
497            .field("child_process", &self.child_process)
498            .field("next_id", &self.next_id)
499            .field("pending_requests", &self.pending_requests)
500            .finish()
501    }
502}
503
504#[derive(Default)]
505struct PendingRequests<W: WarningHandling = WarningsAsErrors> {
506    requests: HashMap<String, PendingRequest<W>>,
507    query_version_requests: HashMap<String, Sender<WarningResult<VersionInfo, W>>>,
508    clear_cache_requests: HashMap<String, Sender<WarningResult<(), W>>>,
509    terminate_requests: HashMap<String, Sender<WarningResult<(), W>>>,
510    terminate_all_requests: HashMap<String, Sender<WarningResult<(), W>>>,
511    query_models_requests: HashMap<String, Sender<WarningResult<Vec<Model>, W>>>,
512}
513
514impl<W: WarningHandling> PendingRequests<W> {
515    async fn poison_all(&mut self, error: Error) {
516        for (_, request) in self.requests.drain() {
517            for sender in request.positions.values() {
518                sender.send_err(error.clone()).await;
519            }
520        }
521
522        for (_, sender) in self.query_version_requests.drain() {
523            sender.send_err(error.clone()).await;
524        }
525
526        for (_, sender) in self.clear_cache_requests.drain() {
527            sender.send_err(error.clone()).await;
528        }
529
530        for (_, sender) in self.terminate_requests.drain() {
531            sender.send_err(error.clone()).await;
532        }
533
534        for (_, sender) in self.terminate_all_requests.drain() {
535            sender.send_err(error.clone()).await;
536        }
537
538        for (_, sender) in self.query_models_requests.drain() {
539            sender.send_err(error.clone()).await;
540        }
541    }
542
543    async fn poison(&mut self, id: &str, error: Error) {
544        if let Some(request) = self.requests.remove(id) {
545            for sender in request.positions.values() {
546                sender.send_err(error.clone()).await;
547            }
548        }
549
550        if let Some(sender) = self.query_version_requests.remove(id) {
551            sender.send_err(error.clone()).await;
552        }
553
554        if let Some(sender) = self.clear_cache_requests.remove(id) {
555            sender.send_err(error.clone()).await;
556        }
557
558        if let Some(sender) = self.terminate_requests.remove(id) {
559            sender.send_err(error.clone()).await;
560        }
561
562        if let Some(sender) = self.terminate_all_requests.remove(id) {
563            sender.send_err(error.clone()).await;
564        }
565
566        if let Some(sender) = self.query_models_requests.remove(id) {
567            sender.send_err(error.clone()).await;
568        }
569    }
570
571    async fn add_warning(&mut self, id: &str, warning: Warning) {
572        if let Some(request) = self.requests.get(id) {
573            for sender in request.positions.values() {
574                sender
575                    .send_modify(|r| W::add_warning(r, warning.clone()))
576                    .await;
577            }
578        }
579
580        if let Some(sender) = self.query_version_requests.get(id) {
581            sender
582                .send_modify(|r| W::add_warning(r, warning.clone()))
583                .await;
584        }
585
586        if let Some(sender) = self.clear_cache_requests.get(id) {
587            sender
588                .send_modify(|r| W::add_warning(r, warning.clone()))
589                .await;
590        }
591
592        if let Some(sender) = self.terminate_requests.get(id) {
593            sender
594                .send_modify(|r| W::add_warning(r, warning.clone()))
595                .await;
596        }
597
598        if let Some(sender) = self.terminate_all_requests.get(id) {
599            sender
600                .send_modify(|r| W::add_warning(r, warning.clone()))
601                .await;
602        }
603
604        if let Some(sender) = self.query_models_requests.get(id) {
605            sender
606                .send_modify(|r| W::add_warning(r, warning.clone()))
607                .await;
608        }
609    }
610}
611
612impl<W: WarningHandling> std::fmt::Debug for PendingRequests<W>
613where
614    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
615    W::OkType<VersionInfo>: std::fmt::Debug,
616    W::OkType<()>: std::fmt::Debug,
617    W::OkType<Vec<Model>>: std::fmt::Debug,
618{
619    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
620        f.debug_struct("PendingRequests")
621            .field("requests", &self.requests)
622            .field("query_version_requests", &self.query_version_requests)
623            .field("clear_cache_requests", &self.clear_cache_requests)
624            .field("terminate_requests", &self.terminate_requests)
625            .field("terminate_all_requests", &self.terminate_all_requests)
626            .field("query_models_requests", &self.query_models_requests)
627            .finish()
628    }
629}
630
631struct PendingRequest<W: WarningHandling = WarningsAsErrors> {
632    positions: HashMap<usize, Sender<WarningResult<Option<AnalysisResult>, W>>>,
633    width: u8,
634    height: u8,
635}
636
637impl<W: WarningHandling> std::fmt::Debug for PendingRequest<W>
638where
639    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
640{
641    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
642        f.debug_struct("PendingRequest")
643            .field("positions", &self.positions)
644            .field("height", &self.height)
645            .finish()
646    }
647}
648
649#[derive(Debug)]
650struct NotifyOnDrop(Arc<Notify>);
651
652impl Drop for NotifyOnDrop {
653    fn drop(&mut self) {
654        self.0.notify_one();
655    }
656}
657
658impl Deref for NotifyOnDrop {
659    type Target = Arc<Notify>;
660
661    fn deref(&self) -> &Self::Target {
662        &self.0
663    }
664}
665
666/// The sender half of a single-producer single-consumer watch channel.
667///
668/// When dropped, is guaranteed to notify the receiver after the last value is sent.
669#[derive(Debug)]
670struct Sender<T> {
671    value: Arc<RwLock<T>>,
672    notify: NotifyOnDrop,
673}
674
675impl<T> Sender<T> {
676    async fn send_modify(&self, f: impl FnOnce(&mut T)) {
677        f(&mut *self.value.write().await);
678        self.notify.notify_one();
679    }
680}
681
682impl<T, E> Sender<std::result::Result<T, E>> {
683    async fn send_err(&self, value: E) {
684        self.send_modify(|r| *r = Err(value)).await;
685    }
686}
687
688/// The receiver half of a single-producer single-consumer watch channel.
689#[derive(Debug)]
690struct Receiver<T> {
691    value: Arc<RwLock<T>>,
692    notify: Arc<Notify>,
693}
694
695impl<T> Receiver<T> {
696    async fn finish(mut self) -> T {
697        loop {
698            match self.poll().await {
699                ControlFlow::Break(value) => return value,
700                ControlFlow::Continue(s) => self = s,
701            };
702        }
703    }
704
705    async fn poll(self) -> ControlFlow<T, Self> {
706        self.notify.notified().await;
707        match Arc::try_unwrap(self.value) {
708            Ok(value) => ControlFlow::Break(value.into_inner()),
709            Err(arc) => ControlFlow::Continue(Self { value: arc, ..self }),
710        }
711    }
712
713    async fn read(&self) -> RwLockReadGuard<'_, T> {
714        self.value.read().await
715    }
716}
717
718/// Creates a single-producer single-consumer watch channel with the given initial value.
719fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
720    let receiver = Receiver {
721        value: Arc::new(RwLock::new(value)),
722        notify: Arc::new(Notify::new()),
723    };
724    let sender = Sender {
725        value: receiver.value.clone(),
726        notify: NotifyOnDrop(receiver.notify.clone()),
727    };
728    (sender, receiver)
729}
730
731/// A collection of in-progress analysis operations for multiple positions in a single game.
732pub struct GameAnalysisProgress<W: WarningHandling = WarningsAsErrors> {
733    id: String,
734    positions: HashMap<usize, AnalysisProgress<W>>,
735}
736
737impl<W: WarningHandling> GameAnalysisProgress<W> {
738    /// Waits for all positions to finish analyzing and returns the results.
739    ///
740    /// Positions that were terminated before any search was performed will not be included in the results.
741    pub async fn finish(self) -> WarningResult<HashMap<usize, AnalysisResult>, W> {
742        let mut results = W::ok(HashMap::new());
743        for (position, progress) in self.into_positions().into_iter() {
744            let result = progress.finish().await;
745            results = W::merge(results, result, |mut results, result| {
746                if let Some(result) = result {
747                    results.insert(position, result);
748                }
749                results
750            });
751        }
752        results
753    }
754
755    /// Returns a reference to the raw collection of in-progress analysis operations for each position.
756    pub fn positions(&self) -> &HashMap<usize, AnalysisProgress<W>> {
757        &self.positions
758    }
759
760    /// Returns a mutable reference to the raw collection of in-progress analysis operations for each position.
761    pub fn positions_mut(&mut self) -> &mut HashMap<usize, AnalysisProgress<W>> {
762        &mut self.positions
763    }
764
765    /// Extracts the collection of in-progress analysis operations for each position and consumes this object.
766    pub fn into_positions(self) -> HashMap<usize, AnalysisProgress<W>> {
767        self.positions
768    }
769}
770
771impl<W: WarningHandling> std::fmt::Debug for GameAnalysisProgress<W>
772where
773    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
774{
775    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
776        f.debug_struct("GameAnalysisProgress")
777            .field("id", &self.id)
778            .field("positions", &self.positions)
779            .finish()
780    }
781}
782
783/// An in-progress analysis operation for a single position.
784pub struct AnalysisProgress<W: WarningHandling = WarningsAsErrors> {
785    receiver: Receiver<WarningResult<Option<AnalysisResult>, W>>,
786    id: String,
787    turn_number: usize,
788}
789
790impl<W: WarningHandling> AnalysisProgress<W> {
791    /// Waits for the analysis to finish and returns the result.
792    ///
793    /// If the analysis was terminated before any search was performed, returns `Ok(None)`.
794    pub async fn finish(self) -> WarningResult<Option<AnalysisResult>, W> {
795        self.receiver.finish().await
796    }
797
798    /// Waits for an analysis update.
799    ///
800    /// This is mainly useful when using [`AnalysisRequest::report_during_search_every`]. Otherwise, it's simpler to
801    /// just call [`finish`](Self::finish) to wait for the final result.
802    ///
803    /// If the analysis is finished, it consumes this object and returns [`ControlFlow::Break`] containing the final
804    /// result. If a partial result is available, it returns [`ControlFlow::Continue`] containing this object again,
805    /// which can be read using [`read`](Self::read) or polled again for the next update.
806    ///
807    /// This method (in combination with [`read`](Self::read)) is conceptually similar to Tokio's
808    /// [`watch`](tokio::sync::watch) channel. It provides a way to follow the latest information as it becomes
809    /// available without any danger of falling behind.
810    ///
811    /// # Example
812    ///
813    /// ```
814    /// # use katago_analysis::*;
815    /// # use std::ops::ControlFlow;
816    /// # async fn example(mut progress: AnalysisProgress) {
817    /// loop {
818    ///     match progress.poll().await {
819    ///         ControlFlow::Break(result) => {
820    ///             match result {
821    ///                 Ok(Some(result)) => {
822    ///                     println!("Winrate: {:.1}%", result.root_info.winrate * 100.0);
823    ///                 }
824    ///                 Ok(None) => println!("No results"),
825    ///                 Err(e) => println!("Error: {e}"),
826    ///             }
827    ///             break;
828    ///         }
829    ///         ControlFlow::Continue(p) => {
830    ///             progress = p;
831    ///             if let Ok(Some(result)) = progress.read().await.as_ref() {
832    ///                 println!(
833    ///                     "Winrate: {:.1}% Visits: {}",
834    ///                     result.root_info.winrate * 100.0,
835    ///                     result.root_info.visits
836    ///                 );
837    ///             }
838    ///         }
839    ///     };
840    /// }
841    /// # }
842    /// ```
843    pub async fn poll(self) -> ControlFlow<WarningResult<Option<AnalysisResult>, W>, Self> {
844        self.receiver.poll().await.map_continue(|r| Self {
845            receiver: r,
846            ..self
847        })
848    }
849
850    /// Reads the latest analysis result available.
851    ///
852    /// See also: [`poll`](Self::poll)
853    pub async fn read(&self) -> RwLockReadGuard<'_, WarningResult<Option<AnalysisResult>, W>> {
854        self.receiver.read().await
855    }
856}
857
858impl<W: WarningHandling> std::fmt::Debug for AnalysisProgress<W>
859where
860    W::OkType<Option<AnalysisResult>>: std::fmt::Debug,
861{
862    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
863        f.debug_struct("AnalysisProgress")
864            .field("receiver", &self.receiver)
865            .field("id", &self.id)
866            .field("turn_number", &self.turn_number)
867            .finish()
868    }
869}