Skip to main content

katago_analysis/
lib.rs

1//! A wrapper around KataGo's analysis protocol.
2//!
3//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
4//! for official documentation of the analysis engine.
5//!
6//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
7//!
8//! # Examples
9//!
10//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
11//!
12//! ```
13//! use katago_analysis::{
14//!     AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
15//!     engine::{Engine, LaunchOptions},
16//! };
17//!
18//! async fn example(
19//!     katago_path: String,
20//!     analysis_config_path: String,
21//!     model_path: String,
22//! ) -> Result<()> {
23//!     let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
24//!     let mut analyzer: Analyzer = Engine::launch(&options)?.into();
25//!
26//!     let request = AnalysisRequest::new(
27//!         Rules::chinese(),
28//!         19,
29//!         19,
30//!         vec![
31//!             (Player::Black, Move::Move(Coord(15, 3))),
32//!             (Player::White, Move::Move(Coord(3, 15))),
33//!         ],
34//!     );
35//!
36//!     let results = analyzer.analyze_game(request).await?;
37//!     for i in 0..results.len() {
38//!         println!(
39//!             "Move {i}: {:.1}%",
40//!             results.get(&i).unwrap().root_info.winrate * 100.0
41//!         );
42//!     }
43//!     Ok(())
44//! }
45//! ```
46
47#![warn(missing_docs)]
48
49use std::{io, sync::Arc};
50
51use serde::{Deserialize, Serialize};
52use thiserror::Error;
53
54pub mod engine;
55
56mod analyzer;
57pub use analyzer::*;
58
59mod config;
60pub use config::*;
61
62mod request;
63pub use request::*;
64
65mod result;
66pub use result::*;
67
68mod rules;
69pub use rules::*;
70
71/// The type of results returned by methods in this library.
72pub type Result<T> = std::result::Result<T, Error>;
73
74/// The type of results which may contain warnings returned by the analysis engine.
75///
76/// See also: [`WarningHandling`]
77pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
78
79/// Errors that can occur while interacting with the analysis engine.
80#[derive(Debug, Clone, Error)]
81pub enum Error {
82    /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
83    #[error("I/O error: {0}")]
84    Io(#[from] Arc<io::Error>),
85
86    /// An error occurred while serializing or deserializing a message.
87    #[error("serialization error: {0}")]
88    Serialization(#[from] Arc<serde_json::Error>),
89
90    /// The engine's stdin was unavailable after launch.
91    #[error("stdin unavailable")]
92    StdinUnavailable,
93
94    /// The engine's stdout was unavailable after launch.
95    #[error("stdout unavailable")]
96    StdoutUnavailable,
97
98    /// The specified config setting's value could not be serialized.
99    #[error("unserializable config value for key {0}")]
100    UnserializableConfig(String),
101
102    /// The analysis engine returned an error response without specifying which request caused it.
103    ///
104    /// When this error occurs, all pending requests will return this error, even if they might have otherwise
105    /// succeeded.
106    #[error("invalid request: {error}")]
107    KataGoGeneralError {
108        /// The error message provided by KataGo.
109        error: String,
110    },
111
112    /// The analysis engine returned an error response.
113    ///
114    /// When this error occurs, all positions still being analyzed as part of the associated request will return this
115    /// error, even if they might have otherwise succeeded.
116    #[error("invalid field {field}: {error}")]
117    KataGoFieldError {
118        /// The error message provided by KataGo.
119        error: String,
120
121        /// The request field that caused the error.
122        field: String,
123    },
124
125    /// The analysis engine returned a warning response which was converted to an error.
126    ///
127    /// See also: [`WarningHandling`]
128    #[error("unhandled warnings: {0:?}")]
129    UnhandledWarnings(Vec<Warning>),
130}
131
132impl From<io::Error> for Error {
133    fn from(e: io::Error) -> Self {
134        Error::Io(Arc::new(e))
135    }
136}
137
138impl From<serde_json::Error> for Error {
139    fn from(e: serde_json::Error) -> Self {
140        Error::Serialization(Arc::new(e))
141    }
142}
143
/// A warning returned by the analysis engine.
///
/// Warnings can be compared and hashed (for example, to de-duplicate them). Their [`Display`]
/// form is `"{field}: {warning}"`.
///
/// [`Display`]: std::fmt::Display
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Warning {
    /// The warning message provided by KataGo.
    pub warning: String,

    /// The request field that caused the warning.
    pub field: String,
}

impl std::fmt::Display for Warning {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.warning)
    }
}
153
154/// Specifies how warnings from the analysis engine should be handled.
155///
156/// # Warning Handling
157///
158/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
159/// warnings will vary depending on which strategy is chosen.
160///
161/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
162///   when a warning occurs. This is the default strategy.
163/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
164/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
165///   This is not recommended unless you expect warnings to occur and don't intend to handle them.
166///
167/// # Examples
168///
169/// ### [`WarningsAsErrors`]
170///
171/// ```
172/// # use katago_analysis::*;
173/// async fn example(
174///     analyzer: &mut Analyzer,
175///     request: AnalysisRequest
176/// ) -> Result<()> {
177///     let result: AnalysisResult = analyzer
178///         .analyze(request)
179///         .await? // Warnings will cause this to fail with an error
180///         .unwrap();
181///     println!("{:.1}%", result.root_info.winrate * 100.0);
182///     Ok(())
183/// }
184/// ```
185///
186/// ### [`ReturnWarnings`]
187///
188/// ```
189/// # use katago_analysis::*;
190/// async fn example(
191///     analyzer: &mut Analyzer<ReturnWarnings>,
192///     request: AnalysisRequest,
193/// ) -> Result<()> {
194///     let result: AnalysisResult = analyzer
195///         .analyze(request)
196///         .await?
197///         .inspect_warnings(|warnings| {
198///             for warning in warnings {
199///                println!("{warning:?}");
200///             }
201///         })
202///         .value
203///         .unwrap();
204///     println!("{:.1}%", result.root_info.winrate * 100.0);
205///     Ok(())
206/// }
207/// ```
208///
209/// Warnings can also be converted back to errors:
210///
211/// ```
212/// # use katago_analysis::*;
213/// async fn example(
214///     analyzer: &mut Analyzer<ReturnWarnings>,
215///     request: AnalysisRequest,
216/// ) -> Result<()> {
217///     let result: AnalysisResult = analyzer
218///         .analyze(request)
219///         .await?
220///         .into_result()? // Warnings will cause this to fail with an error
221///         .unwrap();
222///     println!("{:.1}%", result.root_info.winrate * 100.0);
223///     Ok(())
224/// }
225/// ```
226///
227/// ### [`IgnoreWarnings`]
228///
229/// ```
230/// # use katago_analysis::*;
231/// async fn example(
232///     analyzer: &mut Analyzer<IgnoreWarnings>,
233///     request: AnalysisRequest
234/// ) -> Result<()> {
235///     let result: AnalysisResult = analyzer
236///         .analyze(request)
237///         .await? // Warnings will be ignored
238///         .unwrap();
239///     println!("{:.1}%", result.root_info.winrate * 100.0);
240///     Ok(())
241/// }
242/// ```
243pub trait WarningHandling {
244    /// The `Ok` result type.
245    type OkType<T>;
246
247    /// Creates a successful result containing the given value and no warnings.
248    fn ok<T>(val: T) -> WarningResult<T, Self>;
249
250    /// Updates the successful result, preserving errors and warnings.
251    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
252
253    /// Adds a new warning to the result.
254    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
255
256    /// Applies a function to the successful values of two results, merging warnings and errors.
257    fn merge<T, U, V>(
258        a: WarningResult<T, Self>,
259        b: WarningResult<U, Self>,
260        f: impl FnOnce(T, U) -> V,
261    ) -> WarningResult<V, Self>;
262}
263
264/// Warnings will return [`Error::UnhandledWarnings`].
265///
266/// See also: [`WarningHandling`]
267#[derive(Debug, Default, Clone)]
268pub struct WarningsAsErrors;
269
270impl WarningHandling for WarningsAsErrors {
271    type OkType<T> = T;
272
273    fn ok<T>(val: T) -> WarningResult<T, Self> {
274        Ok(val)
275    }
276
277    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
278        if let Ok(r) = result {
279            *r = value;
280        }
281    }
282
283    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
284        match result {
285            Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
286            Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
287            _ => {}
288        }
289    }
290
291    fn merge<T, U, V>(
292        a: WarningResult<T, Self>,
293        b: WarningResult<U, Self>,
294        f: impl FnOnce(T, U) -> V,
295    ) -> WarningResult<V, Self> {
296        Ok(f(a?, b?))
297    }
298}
299
300/// Warnings will be returned in a [`MaybeWarnings`].
301///
302/// See also: [`WarningHandling`]
303#[derive(Debug, Default, Clone)]
304pub struct ReturnWarnings;
305
306impl WarningHandling for ReturnWarnings {
307    type OkType<T> = MaybeWarnings<T>;
308
309    fn ok<T>(val: T) -> WarningResult<T, Self> {
310        Ok(MaybeWarnings {
311            value: val,
312            warnings: None,
313        })
314    }
315
316    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
317        if let Ok(r) = result {
318            r.value = value;
319        }
320    }
321
322    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
323        if let Ok(r) = result {
324            match r.warnings.as_mut() {
325                Some(warnings) => warnings.push(warning),
326                None => r.warnings = Some(vec![warning]),
327            }
328        }
329    }
330
331    fn merge<T, U, V>(
332        a: WarningResult<T, Self>,
333        b: WarningResult<U, Self>,
334        f: impl FnOnce(T, U) -> V,
335    ) -> WarningResult<V, Self> {
336        let MaybeWarnings {
337            value: a,
338            warnings: a_warnings,
339        } = a?;
340        let MaybeWarnings {
341            value: b,
342            warnings: b_warnings,
343        } = b?;
344        Ok(MaybeWarnings {
345            value: f(a, b),
346            warnings: a_warnings.or(b_warnings),
347        })
348    }
349}
350
351/// A result that may contain warnings.
352#[derive(Debug, Default, Clone)]
353pub struct MaybeWarnings<T> {
354    /// The successful result value.
355    pub value: T,
356
357    /// The list of warnings that occurred, if any.
358    pub warnings: Option<Vec<Warning>>,
359}
360
361impl<T> MaybeWarnings<T> {
362    /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
363    pub fn into_result(self) -> Result<T> {
364        match self.warnings {
365            Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
366            None => Ok(self.value),
367        }
368    }
369
370    /// Calls a function with a reference to the warnings, if any.
371    pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
372        if let Some(warnings) = &self.warnings {
373            f(warnings);
374        }
375        self
376    }
377}
378
379/// Warnings will be dropped.
380///
381/// See also: [`WarningHandling`]
382#[derive(Debug, Default, Clone)]
383pub struct IgnoreWarnings;
384
385impl WarningHandling for IgnoreWarnings {
386    type OkType<T> = T;
387
388    fn ok<T>(val: T) -> WarningResult<T, Self> {
389        Ok(val)
390    }
391
392    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
393        if let Ok(r) = result {
394            *r = value;
395        }
396    }
397
398    fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
399
400    fn merge<T, U, V>(
401        a: WarningResult<T, Self>,
402        b: WarningResult<U, Self>,
403        f: impl FnOnce(T, U) -> V,
404    ) -> WarningResult<V, Self> {
405        Ok(f(a?, b?))
406    }
407}
408
409/// Player colours.
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
411pub enum Player {
412    /// The black player.
413    #[serde(rename = "B")]
414    Black,
415    /// The white player.
416    #[serde(rename = "W")]
417    White,
418}
419
/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// GTP column letters in order; `I` is omitted, as in GTP.
    const GTP_COLUMNS: [u8; 25] = *b"ABCDEFGHJKLMNOPQRSTUVWXYZ";

    /// Converts a coordinate from GTP format (e.g. `"D4"` or `"AB12"`) on a board with `height` rows.
    ///
    /// The vertex must be one or more column letters (case-insensitive, `I` excluded) followed by
    /// a decimal row number in `1..=height`. Columns after `Z` continue as `AA`, `AB`, ...
    /// (bijective base 25).
    ///
    /// Returns `None` in any of these cases:
    /// - the string is malformed (wrong order, sign characters, missing part);
    /// - the row is outside `1..=height`;
    /// - the column does not fit in a `u8`.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        // Letters first, then the row; anything else (e.g. "1A", "A+1") is rejected.
        let split = s
            .find(|c: char| !c.is_ascii_alphabetic())
            .unwrap_or(s.len());
        let (letters, digits) = s.split_at(split);
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }

        // Accumulate in a wider type with checked arithmetic so that large or long column strings
        // yield `None` instead of overflowing (and panicking in debug builds).
        let mut column: u32 = 0;
        for letter in letters.bytes().map(|b| b.to_ascii_uppercase()) {
            let index = Self::GTP_COLUMNS.iter().position(|&l| l == letter)?;
            column = column.checked_mul(25)?.checked_add(index as u32 + 1)?;
        }
        let x = u8::try_from(column.checked_sub(1)?).ok()?;

        // GTP rows are 1-based and counted from the bottom edge; row 0 does not exist.
        let row: u8 = digits.parse().ok()?;
        if row == 0 {
            return None;
        }
        let y = height.checked_sub(row)?;
        Some(Self(x, y))
    }

    /// Converts a coordinate to GTP format on a board with `height` rows.
    ///
    /// This is the inverse of [`Coord::from_gtp`] for on-board coordinates. If `y >= height`, the
    /// row is computed without overflow, so the output has a zero or negative row number (which is
    /// not a valid GTP vertex) rather than panicking.
    pub fn to_gtp(self, height: u8) -> String {
        let Self(x, y) = self;
        let mut gtp = String::with_capacity(5);
        if x >= 25 {
            gtp.push(char::from(Self::GTP_COLUMNS[usize::from(x / 25) - 1]));
        }
        gtp.push(char::from(Self::GTP_COLUMNS[usize::from(x % 25)]));
        // Widen before subtracting so that an off-board `y` cannot underflow `u8`.
        let row = i16::from(height) - i16::from(y);
        gtp.push_str(&row.to_string());
        gtp
    }
}
460
461/// A move in a game.
462#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
463pub enum Move {
464    /// A move placing a stone at the specified coordinate.
465    Move(Coord),
466
467    /// A pass.
468    Pass,
469}
470
471impl Move {
472    /// Converts a move from GTP format.
473    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
474        if s.to_ascii_lowercase() == "pass" {
475            Some(Self::Pass)
476        } else {
477            Coord::from_gtp(s, height).map(Self::Move)
478        }
479    }
480
481    /// Converts a move to GTP format.
482    pub fn to_gtp(self, height: u8) -> String {
483        match self {
484            Move::Move(coord) => coord.to_gtp(height),
485            Move::Pass => "pass".to_string(),
486        }
487    }
488}
489
/// Version details reported by KataGo.
#[derive(Debug, Clone)]
pub struct VersionInfo {
    /// The latest KataGo release this build descends from, e.g. `"1.6.1"`.
    pub version: String,

    /// The exact git hash the build was compiled from. It is `"<omitted>"` when KataGo was built
    /// outside its repository or without Git support.
    pub git_hash: String,
}
501
502/// Information about a neural network model.
503#[derive(Debug, Clone, Deserialize)]
504#[serde(rename_all = "camelCase")]
505pub struct Model {
506    /// The model name.
507    pub name: String,
508
509    /// The internal name.
510    pub internal_name: String,
511
512    /// The maximum batch size.
513    pub max_batch_size: u32,
514
515    /// Whether it uses a humanSL profile.
516    #[serde(rename = "usesHumanSLProfile")]
517    pub uses_humansl_profile: bool,
518
519    /// The model version.
520    pub version: u32,
521
522    /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
523    /// it will be enabled if the backend deems it to be beneficial.
524    #[serde(rename = "usingFP16")]
525    pub using_fp16: Enabled,
526}
527
528/// The enabled state of a feature.
529#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
530#[serde(rename_all = "lowercase")]
531pub enum Enabled {
532    /// The feature is disabled.
533    False,
534
535    /// The feature is enabled.
536    True,
537
538    /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
539    Auto,
540}
541
#[cfg(test)]
mod tests {
    use crate::Coord;

    #[test]
    fn coord_from_gtp() {
        let cases: [(&str, u8, Option<Coord>); 10] = [
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for (gtp, height, expected) in cases {
            assert_eq!(
                Coord::from_gtp(gtp, height),
                expected,
                "from_gtp({gtp:?}, {height})"
            );
        }
    }

    #[test]
    fn coord_to_gtp() {
        let cases = [
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for (coord, height, expected) in cases {
            assert_eq!(coord.to_gtp(height), expected, "{coord:?}.to_gtp({height})");
        }
    }
}