Skip to main content

katago_analysis/
lib.rs

1//! A wrapper around KataGo's analysis protocol.
2//!
3//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
4//! for official documentation of the analysis engine.
5//!
6//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
7//!
8//! # Examples
9//!
10//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
11//!
12//! ```
13//! use katago_analysis::{
14//!     AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
15//!     engine::{Engine, LaunchOptions},
16//! };
17//!
18//! async fn example(
19//!     katago_path: String,
20//!     analysis_config_path: String,
21//!     model_path: String,
22//! ) -> Result<()> {
23//!     let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
24//!     let mut analyzer: Analyzer = Engine::launch(&options)?.into();
25//!
26//!     let request = AnalysisRequest::new(
27//!         Rules::chinese(),
28//!         19,
29//!         19,
30//!         vec![
31//!             (Player::Black, Move::Move(Coord(15, 3))),
32//!             (Player::White, Move::Move(Coord(3, 15))),
33//!         ],
34//!     );
35//!
36//!     let results = analyzer.analyze_game(request).await?;
37//!     for i in 0..results.len() {
38//!         println!(
39//!             "Move {i}: {:.1}%",
40//!             results.get(&i).unwrap().root_info.winrate * 100.0
41//!         );
42//!     }
43//!     Ok(())
44//! }
45//! ```
46
47#![warn(missing_docs)]
48
49use std::{io, sync::Arc};
50
51use serde::{Deserialize, Serialize};
52use thiserror::Error;
53
54pub mod engine;
55
56mod analyzer;
57pub use analyzer::*;
58
59mod config;
60pub use config::*;
61
62mod request;
63pub use request::*;
64
65mod result;
66pub use result::*;
67
68mod rules;
69pub use rules::*;
70
71/// The type of results returned by methods in this library.
72pub type Result<T> = std::result::Result<T, Error>;
73
74/// The type of results which may contain warnings returned by the analysis engine.
75///
76/// See also: [`WarningHandling`]
77pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
78
/// Errors that can occur while interacting with the analysis engine.
///
/// This type is [`Clone`]; the wrapped source errors that are not (`io::Error` and
/// `serde_json::Error`) are therefore stored behind an [`Arc`]. Bare `io::Error` and
/// `serde_json::Error` values can still be converted into this type via `From`.
#[derive(Debug, Clone, Error)]
pub enum Error {
    /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
    #[error("I/O error: {0}")]
    Io(#[from] Arc<io::Error>),

    /// An error occurred while serializing or deserializing a message.
    #[error("serialization error: {0}")]
    Serialization(#[from] Arc<serde_json::Error>),

    /// The engine's stdin was unavailable after launch.
    #[error("stdin unavailable")]
    StdinUnavailable,

    /// The engine's stdout was unavailable after launch.
    #[error("stdout unavailable")]
    StdoutUnavailable,

    /// The specified config setting's value could not be serialized.
    ///
    /// The contained string is the offending config key.
    #[error("unserializable config value for key {0}")]
    UnserializableConfig(String),

    /// The analysis engine returned an error response without specifying which request caused it.
    ///
    /// When this error occurs, all pending requests will return this error, even if they might have otherwise
    /// succeeded.
    #[error("invalid request: {error}")]
    KataGoGeneralError {
        /// The error message provided by KataGo.
        error: String,
    },

    /// The analysis engine returned an error response attributed to a specific request field.
    ///
    /// When this error occurs, all positions still being analyzed as part of the associated request will return this
    /// error, even if they might have otherwise succeeded.
    #[error("invalid field {field}: {error}")]
    KataGoFieldError {
        /// The error message provided by KataGo.
        error: String,

        /// The request field that caused the error.
        field: String,
    },

    /// The analysis engine returned a warning response which was converted to an error.
    ///
    /// Produced, for example, by the [`WarningsAsErrors`] strategy and by
    /// [`MaybeWarnings::into_result`]. The `Display` output lists the warnings using their `Debug`
    /// representation.
    ///
    /// See also: [`WarningHandling`]
    #[error("unhandled warnings: {0:?}")]
    UnhandledWarnings(Vec<Warning>),
}
131
132impl From<io::Error> for Error {
133    fn from(e: io::Error) -> Self {
134        Error::Io(Arc::new(e))
135    }
136}
137
138impl From<serde_json::Error> for Error {
139    fn from(e: serde_json::Error) -> Self {
140        Error::Serialization(Arc::new(e))
141    }
142}
143
/// A warning returned by the analysis engine.
///
/// Warnings are plain data: besides being cloneable they can be compared and hashed, e.g. to
/// deduplicate warnings collected from several requests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Warning {
    /// The warning message provided by KataGo.
    pub warning: String,

    /// The request field that caused the warning.
    pub field: String,
}
153
154/// Specifies how warnings from the analysis engine should be handled.
155///
156/// # Warning Handling
157///
158/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
159/// warnings will vary depending on which strategy is chosen.
160///
161/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
162///   when a warning occurs. This is the default strategy.
163/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
164/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
165///   This is not recommended unless you expect warnings to occur and don't intend to handle them.
166///
167/// # Examples
168///
169/// ### [`WarningsAsErrors`]
170///
171/// ```
172/// # use katago_analysis::*;
173/// async fn example(
174///     analyzer: &mut Analyzer,
175///     request: AnalysisRequest
176/// ) -> Result<()> {
177///     let result: AnalysisResult = analyzer
178///         .analyze(request)
179///         .await? // Warnings will cause this to fail with an error
180///         .unwrap();
181///     println!("{:.1}%", result.root_info.winrate * 100.0);
182///     Ok(())
183/// }
184/// ```
185///
186/// ### [`ReturnWarnings`]
187///
188/// ```
189/// # use katago_analysis::*;
190/// async fn example(
191///     analyzer: &mut Analyzer<ReturnWarnings>,
192///     request: AnalysisRequest,
193/// ) -> Result<()> {
194///     let result: AnalysisResult = analyzer
195///         .analyze(request)
196///         .await?
197///         .inspect_warnings(|warnings| {
198///             for warning in warnings {
199///                println!("{warning:?}");
200///             }
201///         })
202///         .value
203///         .unwrap();
204///     println!("{:.1}%", result.root_info.winrate * 100.0);
205///     Ok(())
206/// }
207/// ```
208///
209/// Warnings can also be converted back to errors:
210///
211/// ```
212/// # use katago_analysis::*;
213/// async fn example(
214///     analyzer: &mut Analyzer<ReturnWarnings>,
215///     request: AnalysisRequest,
216/// ) -> Result<()> {
217///     let result: AnalysisResult = analyzer
218///         .analyze(request)
219///         .await?
220///         .into_result()? // Warnings will cause this to fail with an error
221///         .unwrap();
222///     println!("{:.1}%", result.root_info.winrate * 100.0);
223///     Ok(())
224/// }
225/// ```
226///
227/// ### [`IgnoreWarnings`]
228///
229/// ```
230/// # use katago_analysis::*;
231/// async fn example(
232///     analyzer: &mut Analyzer<IgnoreWarnings>,
233///     request: AnalysisRequest
234/// ) -> Result<()> {
235///     let result: AnalysisResult = analyzer
236///         .analyze(request)
237///         .await? // Warnings will be ignored
238///         .unwrap();
239///     println!("{:.1}%", result.root_info.winrate * 100.0);
240///     Ok(())
241/// }
242/// ```
243pub trait WarningHandling {
244    /// The `Ok` result type.
245    type OkType<T>;
246
247    /// Creates a successful result containing the given value and no warnings.
248    fn ok<T>(val: T) -> WarningResult<T, Self>;
249
250    /// Updates the successful result, preserving errors and warnings.
251    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
252
253    /// Adds a new warning to the result.
254    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
255
256    /// Applies a function to the successful values of two results, merging warnings and errors.
257    fn merge<T, U, V>(
258        a: WarningResult<T, Self>,
259        b: WarningResult<U, Self>,
260        f: impl FnOnce(T, U) -> V,
261    ) -> WarningResult<V, Self>;
262}
263
264/// Warnings will return [`Error::UnhandledWarnings`].
265///
266/// See also: [`WarningHandling`]
267#[derive(Debug, Default, Clone)]
268pub struct WarningsAsErrors;
269
270impl WarningHandling for WarningsAsErrors {
271    type OkType<T> = T;
272
273    fn ok<T>(val: T) -> WarningResult<T, Self> {
274        Ok(val)
275    }
276
277    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
278        if let Ok(r) = result {
279            *r = value;
280        }
281    }
282
283    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
284        match result {
285            Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
286            Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
287            _ => {}
288        }
289    }
290
291    fn merge<T, U, V>(
292        a: WarningResult<T, Self>,
293        b: WarningResult<U, Self>,
294        f: impl FnOnce(T, U) -> V,
295    ) -> WarningResult<V, Self> {
296        Ok(f(a?, b?))
297    }
298}
299
300/// Warnings will be returned in a [`MaybeWarnings`].
301///
302/// See also: [`WarningHandling`]
303#[derive(Debug, Default, Clone)]
304pub struct ReturnWarnings;
305
306impl WarningHandling for ReturnWarnings {
307    type OkType<T> = MaybeWarnings<T>;
308
309    fn ok<T>(val: T) -> WarningResult<T, Self> {
310        Ok(MaybeWarnings {
311            value: val,
312            warnings: None,
313        })
314    }
315
316    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
317        if let Ok(r) = result {
318            r.value = value;
319        }
320    }
321
322    fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
323        if let Ok(r) = result {
324            match r.warnings.as_mut() {
325                Some(warnings) => warnings.push(warning),
326                None => r.warnings = Some(vec![warning]),
327            }
328        }
329    }
330
331    fn merge<T, U, V>(
332        a: WarningResult<T, Self>,
333        b: WarningResult<U, Self>,
334        f: impl FnOnce(T, U) -> V,
335    ) -> WarningResult<V, Self> {
336        let MaybeWarnings {
337            value: a,
338            warnings: a_warnings,
339        } = a?;
340        let MaybeWarnings {
341            value: b,
342            warnings: b_warnings,
343        } = b?;
344        Ok(MaybeWarnings {
345            value: f(a, b),
346            warnings: a_warnings.or(b_warnings),
347        })
348    }
349}
350
351/// A result that may contain warnings.
352#[derive(Debug, Default, Clone)]
353pub struct MaybeWarnings<T> {
354    /// The successful result value.
355    pub value: T,
356
357    /// The list of warnings that occurred, if any.
358    pub warnings: Option<Vec<Warning>>,
359}
360
361impl<T> MaybeWarnings<T> {
362    /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
363    pub fn into_result(self) -> Result<T> {
364        match self.warnings {
365            Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
366            None => Ok(self.value),
367        }
368    }
369
370    /// Calls a function with a reference to the warnings, if any.
371    pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
372        if let Some(warnings) = &self.warnings {
373            f(warnings);
374        }
375        self
376    }
377}
378
379/// Warnings will be dropped.
380///
381/// See also: [`WarningHandling`]
382#[derive(Debug, Default, Clone)]
383pub struct IgnoreWarnings;
384
385impl WarningHandling for IgnoreWarnings {
386    type OkType<T> = T;
387
388    fn ok<T>(val: T) -> WarningResult<T, Self> {
389        Ok(val)
390    }
391
392    fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
393        if let Ok(r) = result {
394            *r = value;
395        }
396    }
397
398    fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
399
400    fn merge<T, U, V>(
401        a: WarningResult<T, Self>,
402        b: WarningResult<U, Self>,
403        f: impl FnOnce(T, U) -> V,
404    ) -> WarningResult<V, Self> {
405        Ok(f(a?, b?))
406    }
407}
408
/// Player colours.
///
/// Serialized as the single letters `"B"` and `"W"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Player {
    /// The black player, serialized as `"B"`.
    #[serde(rename = "B")]
    Black,
    /// The white player, serialized as `"W"`.
    #[serde(rename = "W")]
    White,
}
419
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::Color> for Player {
    /// Maps an SGF colour onto the [`Player`] of the same colour.
    fn from(color: sgf_parse::Color) -> Self {
        match color {
            sgf_parse::Color::White => Self::White,
            sgf_parse::Color::Black => Self::Black,
        }
    }
}
429
/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
///
/// `x` is the column and `y` the row, counted from the top edge.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// Converts a coordinate from GTP format, e.g. `"Q16"`.
    ///
    /// The column is one or more letters (case-insensitive, skipping `I`: `A`..`H`, `J`..`Z`, then
    /// `AA`, `AB`, ... on wide boards). The row is a decimal number counted from the bottom edge, where
    /// `1` is the bottom row of a board with `height` rows.
    ///
    /// Returns `None` in any of these cases:
    /// - the string is not letters followed by digits;
    /// - the column contains `I`;
    /// - the column does not fit in a `u8`;
    /// - the row is `0`, or lies beyond `height`.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        // Split into leading column letters and trailing row digits. The letters are ASCII, so the
        // split index is always a char boundary.
        let split = s.find(|c: char| !c.is_ascii_alphabetic()).unwrap_or(s.len());
        let (letters, digits) = s.split_at(split);
        if digits.is_empty() || !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }

        // The column is a bijective base-25 numeral over A..Z without I (A = 1, ..., H = 8,
        // J = 9, ..., Z = 25). Accumulate in `u32` with checked arithmetic so over-long columns are
        // rejected instead of overflowing, and so column 255 (`KF`) is still representable.
        let mut column: u32 = 0;
        for letter in letters.bytes().map(|b| b.to_ascii_uppercase()) {
            let digit = match letter {
                b'I' => return None,
                b'A'..=b'H' => letter - b'A' + 1,
                _ => letter - b'A',
            };
            column = column.checked_mul(25)?.checked_add(u32::from(digit))?;
        }
        let x = u8::try_from(column.checked_sub(1)?).ok()?;

        // Rows count up from 1 at the bottom edge; row 0 is not on the board.
        let row: u8 = digits.parse().ok()?;
        if row == 0 {
            return None;
        }
        let y = height.checked_sub(row)?;
        Some(Self(x, y))
    }

    /// Converts a coordinate to GTP format for a board with `height` rows.
    ///
    /// For coordinates on the board this is the inverse of [`Coord::from_gtp`].
    ///
    /// # Panics
    ///
    /// `y` is expected to be less than `height`. If `y > height`, the row subtraction underflows.
    /// This panics in debug builds.
    pub fn to_gtp(self, height: u8) -> String {
        const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
        let Self(mut x, y) = self;
        // At most two column letters plus three row digits.
        let mut gtp = String::with_capacity(5);
        if x >= 25 {
            // Two-letter column: the leading letter encodes `x / 25` as a 1-based digit.
            gtp.push(LETTERS[(x / 25) as usize - 1] as char);
            x %= 25;
        }
        gtp.push(LETTERS[x as usize] as char);
        gtp.push_str(&(height - y).to_string());
        gtp
    }
}
470
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Point> for Coord {
    /// Copies the point's `x` and `y` components directly into a [`Coord`].
    fn from(point: sgf_parse::go::Point) -> Self {
        let sgf_parse::go::Point { x, y, .. } = point;
        Coord(x, y)
    }
}
477
478/// A move in a game.
479#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
480pub enum Move {
481    /// A move placing a stone at the specified coordinate.
482    Move(Coord),
483
484    /// A pass.
485    Pass,
486}
487
488impl Move {
489    /// Converts a move from GTP format.
490    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
491        if s.eq_ignore_ascii_case("pass") {
492            Some(Self::Pass)
493        } else {
494            Coord::from_gtp(s, height).map(Self::Move)
495        }
496    }
497
498    /// Converts a move to GTP format.
499    pub fn to_gtp(self, height: u8) -> String {
500        match self {
501            Move::Move(coord) => coord.to_gtp(height),
502            Move::Pass => "pass".to_string(),
503        }
504    }
505}
506
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Move> for Move {
    /// Converts an SGF move; a placement's point goes through the `Point` to [`Coord`] conversion.
    fn from(m: sgf_parse::go::Move) -> Self {
        match m {
            sgf_parse::go::Move::Pass => Self::Pass,
            sgf_parse::go::Move::Move(point) => Self::Move(Coord::from(point)),
        }
    }
}
516
/// KataGo's version information.
///
/// Pairs a release version string with the git hash the binary was built from.
#[derive(Debug, Clone)]
pub struct VersionInfo {
    /// A string indicating the most recent KataGo release version that this version is a descendant of,
    /// such as `"1.6.1"`.
    pub version: String,

    /// The precise git hash this KataGo version was compiled from, or the string `"<omitted>"` if KataGo was
    /// compiled separately from its repo or without Git support.
    pub git_hash: String,
}
528
/// Information about a neural network model.
///
/// Deserialized from JSON with camelCase keys (e.g. `internalName`, `maxBatchSize`). The two fields
/// whose keys do not follow that convention, `usesHumanSLProfile` and `usingFP16`, are renamed
/// explicitly.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Model {
    /// The model name.
    pub name: String,

    /// The internal name.
    pub internal_name: String,

    /// The maximum batch size.
    pub max_batch_size: u32,

    /// Whether it uses a humanSL profile.
    #[serde(rename = "usesHumanSLProfile")]
    pub uses_humansl_profile: bool,

    /// The model version.
    pub version: u32,

    /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
    /// it will be enabled if the backend deems it to be beneficial.
    #[serde(rename = "usingFP16")]
    pub using_fp16: Enabled,
}
554
/// The enabled state of a feature.
///
/// Deserialized from the lowercase strings `"false"`, `"true"`, and `"auto"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Enabled {
    /// The feature is disabled.
    False,

    /// The feature is enabled.
    True,

    /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
    Auto,
}
568
#[cfg(test)]
mod tests {
    use crate::Coord;

    /// Parsing GTP strings, including rejected inputs and multi-letter columns on wide boards.
    #[test]
    fn coord_from_gtp() {
        let cases: &[(&str, u8, Option<Coord>)] = &[
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for &(gtp, height, expected) in cases.iter() {
            assert_eq!(Coord::from_gtp(gtp, height), expected, "from_gtp({gtp:?}, {height})");
        }
    }

    /// Formatting coordinates back into GTP strings.
    #[test]
    fn coord_to_gtp() {
        let cases: &[(Coord, u8, &str)] = &[
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for &(coord, height, expected) in cases.iter() {
            assert_eq!(coord.to_gtp(height), expected, "{coord:?}.to_gtp({height})");
        }
    }
}