Skip to main content

katago_analysis/
lib.rs

1//! A wrapper around KataGo's analysis protocol.
2//!
3//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
4//! for official documentation of the analysis engine.
5//!
6//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
7//!
8//! # Examples
9//!
10//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
11//!
12//! ```
13//! use katago_analysis::{
14//!     AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
15//!     engine::{Engine, LaunchOptions},
16//! };
17//!
18//! async fn example(
19//!     katago_path: String,
20//!     analysis_config_path: String,
21//!     model_path: String,
22//! ) -> Result<()> {
23//!     let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
24//!     let mut analyzer: Analyzer = Engine::launch(&options)?.into();
25//!
26//!     let request = AnalysisRequest::new(
27//!         Rules::chinese(),
28//!         19,
29//!         19,
30//!         vec![
31//!             (Player::Black, Move::Move(Coord(15, 3))),
32//!             (Player::White, Move::Move(Coord(3, 15))),
33//!         ],
34//!     );
35//!
36//!     let results = analyzer.analyze_game(request).await?;
37//!     for i in 0..results.len() {
38//!         println!(
39//!             "Move {i}: {:.1}%",
40//!             results.get(&i).unwrap().root_info.winrate * 100.0
41//!         );
42//!     }
43//!     Ok(())
44//! }
45//! ```
46
47#![cfg_attr(docsrs, feature(doc_cfg))]
48#![warn(missing_docs)]
49
50use std::{io, sync::Arc};
51
52use serde::{Deserialize, Serialize};
53use thiserror::Error;
54
55pub mod engine;
56
57mod analyzer;
58pub use analyzer::*;
59
60mod config;
61pub use config::*;
62
63mod request;
64pub use request::*;
65
66mod result;
67pub use result::*;
68
69mod rules;
70pub use rules::*;
71
72/// The type of results returned by methods in this library.
73pub type Result<T> = std::result::Result<T, Error>;
74
75/// The type of results which may contain warnings returned by the analysis engine.
76///
77/// See also: [`WarningHandling`]
78pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
79
80/// Errors that can occur while interacting with the analysis engine.
81#[derive(Debug, Clone, Error)]
82pub enum Error {
83    /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
84    #[error("I/O error: {0}")]
85    Io(#[from] Arc<io::Error>),
86
87    /// An error occurred while serializing or deserializing a message.
88    #[error("serialization error: {0}")]
89    Serialization(#[from] Arc<serde_json::Error>),
90
91    /// The engine's stdin was unavailable after launch.
92    #[error("stdin unavailable")]
93    StdinUnavailable,
94
95    /// The engine's stdout was unavailable after launch.
96    #[error("stdout unavailable")]
97    StdoutUnavailable,
98
99    /// The specified config setting's value could not be serialized.
100    #[error("unserializable config value for key {0}")]
101    UnserializableConfig(String),
102
103    /// The analysis engine returned an error response without specifying which request caused it.
104    ///
105    /// When this error occurs, all pending requests will return this error, even if they might have otherwise
106    /// succeeded.
107    #[error("invalid request: {error}")]
108    KataGoGeneralError {
109        /// The error message provided by KataGo.
110        error: String,
111    },
112
113    /// The analysis engine returned an error response.
114    ///
115    /// When this error occurs, all positions still being analyzed as part of the associated request will return this
116    /// error, even if they might have otherwise succeeded.
117    #[error("invalid field {field}: {error}")]
118    KataGoFieldError {
119        /// The error message provided by KataGo.
120        error: String,
121
122        /// The request field that caused the error.
123        field: String,
124    },
125
126    /// The analysis engine returned a warning response which was converted to an error.
127    ///
128    /// See also: [`WarningHandling`]
129    #[error("unhandled warnings: {0:?}")]
130    UnhandledWarnings(Vec<Warning>),
131}
132
133impl From<io::Error> for Error {
134    fn from(e: io::Error) -> Self {
135        Error::Io(Arc::new(e))
136    }
137}
138
139impl From<serde_json::Error> for Error {
140    fn from(e: serde_json::Error) -> Self {
141        Error::Serialization(Arc::new(e))
142    }
143}
144
/// A single warning reported by the analysis engine.
#[derive(Debug, Clone)]
pub struct Warning {
    /// KataGo's warning message.
    pub warning: String,

    /// The request field the warning refers to.
    pub field: String,
}
154
155/// Specifies how warnings from the analysis engine should be handled.
156///
157/// # Warning Handling
158///
159/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
160/// warnings will vary depending on which strategy is chosen.
161///
162/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
163///   when a warning occurs. This is the default strategy.
164/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
165/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
166///   This is not recommended unless you expect warnings to occur and don't intend to handle them.
167///
168/// # Examples
169///
170/// ### [`WarningsAsErrors`]
171///
172/// ```
173/// # use katago_analysis::*;
174/// async fn example(
175///     analyzer: &mut Analyzer,
176///     request: AnalysisRequest
177/// ) -> Result<()> {
178///     let result: AnalysisResult = analyzer
179///         .analyze(request)
180///         .await? // Warnings will cause this to fail with an error
181///         .unwrap();
182///     println!("{:.1}%", result.root_info.winrate * 100.0);
183///     Ok(())
184/// }
185/// ```
186///
187/// ### [`ReturnWarnings`]
188///
189/// ```
190/// # use katago_analysis::*;
191/// async fn example(
192///     analyzer: &mut Analyzer<ReturnWarnings>,
193///     request: AnalysisRequest,
194/// ) -> Result<()> {
195///     let result: AnalysisResult = analyzer
196///         .analyze(request)
197///         .await?
198///         .inspect_warnings(|warnings| {
199///             for warning in warnings {
200///                println!("{warning:?}");
201///             }
202///         })
203///         .value
204///         .unwrap();
205///     println!("{:.1}%", result.root_info.winrate * 100.0);
206///     Ok(())
207/// }
208/// ```
209///
210/// Warnings can also be converted back to errors:
211///
212/// ```
213/// # use katago_analysis::*;
214/// async fn example(
215///     analyzer: &mut Analyzer<ReturnWarnings>,
216///     request: AnalysisRequest,
217/// ) -> Result<()> {
218///     let result: AnalysisResult = analyzer
219///         .analyze(request)
220///         .await?
221///         .into_result()? // Warnings will cause this to fail with an error
222///         .unwrap();
223///     println!("{:.1}%", result.root_info.winrate * 100.0);
224///     Ok(())
225/// }
226/// ```
227///
228/// ### [`IgnoreWarnings`]
229///
230/// ```
231/// # use katago_analysis::*;
232/// async fn example(
233///     analyzer: &mut Analyzer<IgnoreWarnings>,
234///     request: AnalysisRequest
235/// ) -> Result<()> {
236///     let result: AnalysisResult = analyzer
237///         .analyze(request)
238///         .await? // Warnings will be ignored
239///         .unwrap();
240///     println!("{:.1}%", result.root_info.winrate * 100.0);
241///     Ok(())
242/// }
243/// ```
244pub trait WarningHandling {
245    /// The `Ok` result type.
246    type OkType<T>;
247
248    /// Creates a successful result containing the given value and no warnings.
249    fn ok<T>(val: T) -> WarningResult<T, Self>;
250
251    /// Updates the successful result, preserving errors and warnings.
252    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
253
254    /// Adds a new warning to the result.
255    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
256
257    /// Applies a function to the successful values of two results, merging warnings and errors.
258    fn merge<T, U, V>(
259        a: WarningResult<T, Self>,
260        b: WarningResult<U, Self>,
261        f: impl FnOnce(T, U) -> V,
262    ) -> WarningResult<V, Self>;
263}
264
265/// Warnings will return [`Error::UnhandledWarnings`].
266///
267/// See also: [`WarningHandling`]
268#[derive(Debug, Default, Clone)]
269pub struct WarningsAsErrors;
270
271impl WarningHandling for WarningsAsErrors {
272    type OkType<T> = T;
273
274    fn ok<T>(val: T) -> WarningResult<T, Self> {
275        Ok(val)
276    }
277
278    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
279        if let Ok(r) = result {
280            *r = value;
281        }
282    }
283
284    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
285        match result {
286            Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
287            Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
288            _ => {}
289        }
290    }
291
292    fn merge<T, U, V>(
293        a: WarningResult<T, Self>,
294        b: WarningResult<U, Self>,
295        f: impl FnOnce(T, U) -> V,
296    ) -> WarningResult<V, Self> {
297        Ok(f(a?, b?))
298    }
299}
300
301/// Warnings will be returned in a [`MaybeWarnings`].
302///
303/// See also: [`WarningHandling`]
304#[derive(Debug, Default, Clone)]
305pub struct ReturnWarnings;
306
307impl WarningHandling for ReturnWarnings {
308    type OkType<T> = MaybeWarnings<T>;
309
310    fn ok<T>(val: T) -> WarningResult<T, Self> {
311        Ok(MaybeWarnings {
312            value: val,
313            warnings: None,
314        })
315    }
316
317    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
318        if let Ok(r) = result {
319            r.value = value;
320        }
321    }
322
323    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
324        if let Ok(r) = result {
325            match r.warnings.as_mut() {
326                Some(warnings) => warnings.push(warning),
327                None => r.warnings = Some(vec![warning]),
328            }
329        }
330    }
331
332    fn merge<T, U, V>(
333        a: WarningResult<T, Self>,
334        b: WarningResult<U, Self>,
335        f: impl FnOnce(T, U) -> V,
336    ) -> WarningResult<V, Self> {
337        let MaybeWarnings {
338            value: a,
339            warnings: a_warnings,
340        } = a?;
341        let MaybeWarnings {
342            value: b,
343            warnings: b_warnings,
344        } = b?;
345        Ok(MaybeWarnings {
346            value: f(a, b),
347            warnings: a_warnings.or(b_warnings),
348        })
349    }
350}
351
352/// A result that may contain warnings.
353#[derive(Debug, Default, Clone)]
354pub struct MaybeWarnings<T> {
355    /// The successful result value.
356    pub value: T,
357
358    /// The list of warnings that occurred, if any.
359    pub warnings: Option<Vec<Warning>>,
360}
361
362impl<T> MaybeWarnings<T> {
363    /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
364    pub fn into_result(self) -> Result<T> {
365        match self.warnings {
366            Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
367            None => Ok(self.value),
368        }
369    }
370
371    /// Calls a function with a reference to the warnings, if any.
372    pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
373        if let Some(warnings) = &self.warnings {
374            f(warnings);
375        }
376        self
377    }
378}
379
380/// Warnings will be dropped.
381///
382/// See also: [`WarningHandling`]
383#[derive(Debug, Default, Clone)]
384pub struct IgnoreWarnings;
385
386impl WarningHandling for IgnoreWarnings {
387    type OkType<T> = T;
388
389    fn ok<T>(val: T) -> WarningResult<T, Self> {
390        Ok(val)
391    }
392
393    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
394        if let Ok(r) = result {
395            *r = value;
396        }
397    }
398
399    fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
400
401    fn merge<T, U, V>(
402        a: WarningResult<T, Self>,
403        b: WarningResult<U, Self>,
404        f: impl FnOnce(T, U) -> V,
405    ) -> WarningResult<V, Self> {
406        Ok(f(a?, b?))
407    }
408}
409
410/// Player colours.
411#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
412pub enum Player {
413    /// The black player.
414    #[serde(rename = "B")]
415    Black,
416    /// The white player.
417    #[serde(rename = "W")]
418    White,
419}
420
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::Color> for Player {
    /// Maps an `sgf_parse` colour onto the matching [`Player`].
    fn from(color: sgf_parse::Color) -> Self {
        use sgf_parse::Color;

        match color {
            Color::Black => Self::Black,
            Color::White => Self::White,
        }
    }
}
430
/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// Converts a coordinate from GTP format (e.g. `"Q16"`), given the board `height`.
    ///
    /// Column letters are case-insensitive and skip `I`; columns past `Z` use two letters
    /// (`AA`, `AB`, ...). The row number counts up from the bottom of the board, starting at 1.
    ///
    /// Returns `None` when:
    /// - there are no column letters, or a letter is `I`;
    /// - the column does not fit in a `u8`;
    /// - the row is missing, not a valid `u8`, zero, or greater than `height`.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        let (letters, digits): (String, String) =
            s.chars().partition(|c| c.is_ascii_alphabetic());

        // Accumulate a 1-based column number in base 25 (the alphabet without `I`).
        // Checked arithmetic turns over-long columns into `None` instead of overflowing.
        let mut column: u8 = 0;
        for c in letters.chars().map(|c| c.to_ascii_uppercase()) {
            let digit = match c {
                'A'..='H' => c as u8 - b'A' + 1,
                // Letters after the skipped `I` shift down by one.
                'J'..='Z' => c as u8 - b'A',
                _ => return None,
            };
            column = column.checked_mul(25)?.checked_add(digit)?;
        }
        // `column` is still 0 when no letters were given.
        let x = column.checked_sub(1)?;

        let row = digits.parse::<u8>().ok()?;
        if row == 0 {
            // GTP rows are 1-based; row 0 would map to `y == height`, which is off the board.
            return None;
        }
        let y = height.checked_sub(row)?;
        Some(Self(x, y))
    }

    /// Converts a coordinate to GTP format (e.g. `"Q16"`), given the board `height`.
    ///
    /// Columns 0-24 use a single letter (skipping `I`); larger columns use two letters.
    ///
    /// The row is computed as `height - y`, so `y` must not exceed `height`; otherwise the
    /// subtraction overflows (a panic in builds with overflow checks enabled).
    pub fn to_gtp(self, height: u8) -> String {
        let Self(mut x, y) = self;
        const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
        let mut gtp = String::with_capacity(4);
        if x >= 25 {
            // `x / 25` is at most 10 for a `u8`, so the index is always in range.
            gtp.push(LETTERS[(x / 25) as usize - 1] as char);
            x %= 25;
        }
        gtp.push(LETTERS[x as usize] as char);
        gtp.push_str(&(height - y).to_string());
        gtp
    }
}
471
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Point> for Coord {
    /// Copies the point's `x` and `y` fields into a [`Coord`].
    fn from(point: sgf_parse::go::Point) -> Self {
        let (x, y) = (point.x, point.y);
        Self(x, y)
    }
}
478
479/// A move in a game.
480#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
481pub enum Move {
482    /// A move placing a stone at the specified coordinate.
483    Move(Coord),
484
485    /// A pass.
486    Pass,
487}
488
489impl Move {
490    /// Converts a move from GTP format.
491    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
492        if s.eq_ignore_ascii_case("pass") {
493            Some(Self::Pass)
494        } else {
495            Coord::from_gtp(s, height).map(Self::Move)
496        }
497    }
498
499    /// Converts a move to GTP format.
500    pub fn to_gtp(self, height: u8) -> String {
501        match self {
502            Move::Move(coord) => coord.to_gtp(height),
503            Move::Pass => "pass".to_string(),
504        }
505    }
506}
507
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Move> for Move {
    /// Maps an `sgf_parse` move onto the matching [`Move`], converting its point to a [`Coord`].
    fn from(sgf_move: sgf_parse::go::Move) -> Self {
        use sgf_parse::go::Move as SgfMove;

        match sgf_move {
            SgfMove::Pass => Self::Pass,
            SgfMove::Move(point) => Self::Move(Coord::from(point)),
        }
    }
}
517
/// Version information reported by KataGo.
#[derive(Debug, Clone)]
pub struct VersionInfo {
    /// The most recent KataGo release version that this build descends from, such as `"1.6.1"`.
    pub version: String,

    /// The exact git hash this KataGo build was compiled from, or the string `"<omitted>"` if
    /// KataGo was compiled separately from its repo or without Git support.
    pub git_hash: String,
}
529
530/// Information about a neural network model.
531#[derive(Debug, Clone, Deserialize)]
532#[serde(rename_all = "camelCase")]
533pub struct Model {
534    /// The model name.
535    pub name: String,
536
537    /// The internal name.
538    pub internal_name: String,
539
540    /// The maximum batch size.
541    pub max_batch_size: u32,
542
543    /// Whether it uses a humanSL profile.
544    #[serde(rename = "usesHumanSLProfile")]
545    pub uses_humansl_profile: bool,
546
547    /// The model version.
548    pub version: u32,
549
550    /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
551    /// it will be enabled if the backend deems it to be beneficial.
552    #[serde(rename = "usingFP16")]
553    pub using_fp16: Enabled,
554}
555
556/// The enabled state of a feature.
557#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
558#[serde(rename_all = "lowercase")]
559pub enum Enabled {
560    /// The feature is disabled.
561    False,
562
563    /// The feature is enabled.
564    True,
565
566    /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
567    Auto,
568}
569
#[cfg(test)]
mod tests {
    use crate::Coord;

    /// Parsing GTP vertices, including invalid inputs and two-letter columns.
    #[test]
    fn coord_from_gtp() {
        let cases: [(&str, u8, Option<Coord>); 10] = [
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for (text, height, expected) in cases {
            assert_eq!(Coord::from_gtp(text, height), expected);
        }
    }

    /// Formatting coordinates as GTP vertices, including two-letter columns.
    #[test]
    fn coord_to_gtp() {
        let cases: [(Coord, u8, &str); 6] = [
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for (coord, height, expected) in cases {
            assert_eq!(coord.to_gtp(height), expected);
        }
    }
}