katago_analysis/lib.rs
1//! A wrapper around KataGo's analysis protocol.
2//!
3//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
4//! for official documentation of the analysis engine.
5//!
6//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
7//!
8//! # Examples
9//!
10//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
11//!
12//! ```
13//! use katago_analysis::{
14//! AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
15//! engine::{Engine, LaunchOptions},
16//! };
17//!
18//! async fn example(
19//! katago_path: String,
20//! analysis_config_path: String,
21//! model_path: String,
22//! ) -> Result<()> {
23//! let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
24//! let mut analyzer: Analyzer = Engine::launch(&options)?.into();
25//!
26//! let request = AnalysisRequest::new(
27//! Rules::chinese(),
28//! 19,
29//! 19,
30//! vec![
31//! (Player::Black, Move::Move(Coord(15, 3))),
32//! (Player::White, Move::Move(Coord(3, 15))),
33//! ],
34//! );
35//!
36//! let results = analyzer.analyze_game(request).await?;
37//! for i in 0..results.len() {
38//! println!(
39//! "Move {i}: {:.1}%",
40//! results.get(&i).unwrap().root_info.winrate * 100.0
41//! );
42//! }
43//! Ok(())
44//! }
45//! ```
46
47#![warn(missing_docs)]
48
49use std::{io, sync::Arc};
50
51use serde::{Deserialize, Serialize};
52use thiserror::Error;
53
54pub mod engine;
55
56mod analyzer;
57pub use analyzer::*;
58
59mod config;
60pub use config::*;
61
62mod request;
63pub use request::*;
64
65mod result;
66pub use result::*;
67
68mod rules;
69pub use rules::*;
70
71/// The type of results returned by methods in this library.
72pub type Result<T> = std::result::Result<T, Error>;
73
74/// The type of results which may contain warnings returned by the analysis engine.
75///
76/// See also: [`WarningHandling`]
77pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
78
79/// Errors that can occur while interacting with the analysis engine.
80#[derive(Debug, Clone, Error)]
81pub enum Error {
82 /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
83 #[error("I/O error: {0}")]
84 Io(#[from] Arc<io::Error>),
85
86 /// An error occurred while serializing or deserializing a message.
87 #[error("serialization error: {0}")]
88 Serialization(#[from] Arc<serde_json::Error>),
89
90 /// The engine's stdin was unavailable after launch.
91 #[error("stdin unavailable")]
92 StdinUnavailable,
93
94 /// The engine's stdout was unavailable after launch.
95 #[error("stdout unavailable")]
96 StdoutUnavailable,
97
98 /// The specified config setting's value could not be serialized.
99 #[error("unserializable config value for key {0}")]
100 UnserializableConfig(String),
101
102 /// The analysis engine returned an error response without specifying which request caused it.
103 ///
104 /// When this error occurs, all pending requests will return this error, even if they might have otherwise
105 /// succeeded.
106 #[error("invalid request: {error}")]
107 KataGoGeneralError {
108 /// The error message provided by KataGo.
109 error: String,
110 },
111
112 /// The analysis engine returned an error response.
113 ///
114 /// When this error occurs, all positions still being analyzed as part of the associated request will return this
115 /// error, even if they might have otherwise succeeded.
116 #[error("invalid field {field}: {error}")]
117 KataGoFieldError {
118 /// The error message provided by KataGo.
119 error: String,
120
121 /// The request field that caused the error.
122 field: String,
123 },
124
125 /// The analysis engine returned a warning response which was converted to an error.
126 ///
127 /// See also: [`WarningHandling`]
128 #[error("unhandled warnings: {0:?}")]
129 UnhandledWarnings(Vec<Warning>),
130}
131
132impl From<io::Error> for Error {
133 fn from(e: io::Error) -> Self {
134 Error::Io(Arc::new(e))
135 }
136}
137
138impl From<serde_json::Error> for Error {
139 fn from(e: serde_json::Error) -> Self {
140 Error::Serialization(Arc::new(e))
141 }
142}
143
/// A warning returned by the analysis engine.
///
/// Each warning names the request field it concerns. Depending on the chosen
/// [`WarningHandling`] strategy, warnings are turned into errors, returned alongside the
/// result, or dropped.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Warning {
    /// The warning message provided by KataGo.
    pub warning: String,

    /// The request field that caused the warning.
    pub field: String,
}
153
154/// Specifies how warnings from the analysis engine should be handled.
155///
156/// # Warning Handling
157///
158/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
159/// warnings will vary depending on which strategy is chosen.
160///
161/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
162/// when a warning occurs. This is the default strategy.
163/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
164/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
165/// This is not recommended unless you expect warnings to occur and don't intend to handle them.
166///
167/// # Examples
168///
169/// ### [`WarningsAsErrors`]
170///
171/// ```
172/// # use katago_analysis::*;
173/// async fn example(
174/// analyzer: &mut Analyzer,
175/// request: AnalysisRequest
176/// ) -> Result<()> {
177/// let result: AnalysisResult = analyzer
178/// .analyze(request)
179/// .await? // Warnings will cause this to fail with an error
180/// .unwrap();
181/// println!("{:.1}%", result.root_info.winrate * 100.0);
182/// Ok(())
183/// }
184/// ```
185///
186/// ### [`ReturnWarnings`]
187///
188/// ```
189/// # use katago_analysis::*;
190/// async fn example(
191/// analyzer: &mut Analyzer<ReturnWarnings>,
192/// request: AnalysisRequest,
193/// ) -> Result<()> {
194/// let result: AnalysisResult = analyzer
195/// .analyze(request)
196/// .await?
197/// .inspect_warnings(|warnings| {
198/// for warning in warnings {
199/// println!("{warning:?}");
200/// }
201/// })
202/// .value
203/// .unwrap();
204/// println!("{:.1}%", result.root_info.winrate * 100.0);
205/// Ok(())
206/// }
207/// ```
208///
209/// Warnings can also be converted back to errors:
210///
211/// ```
212/// # use katago_analysis::*;
213/// async fn example(
214/// analyzer: &mut Analyzer<ReturnWarnings>,
215/// request: AnalysisRequest,
216/// ) -> Result<()> {
217/// let result: AnalysisResult = analyzer
218/// .analyze(request)
219/// .await?
220/// .into_result()? // Warnings will cause this to fail with an error
221/// .unwrap();
222/// println!("{:.1}%", result.root_info.winrate * 100.0);
223/// Ok(())
224/// }
225/// ```
226///
227/// ### [`IgnoreWarnings`]
228///
229/// ```
230/// # use katago_analysis::*;
231/// async fn example(
232/// analyzer: &mut Analyzer<IgnoreWarnings>,
233/// request: AnalysisRequest
234/// ) -> Result<()> {
235/// let result: AnalysisResult = analyzer
236/// .analyze(request)
237/// .await? // Warnings will be ignored
238/// .unwrap();
239/// println!("{:.1}%", result.root_info.winrate * 100.0);
240/// Ok(())
241/// }
242/// ```
243pub trait WarningHandling {
244 /// The `Ok` result type.
245 type OkType<T>;
246
247 /// Creates a successful result containing the given value and no warnings.
248 fn ok<T>(val: T) -> WarningResult<T, Self>;
249
250 /// Updates the successful result, preserving errors and warnings.
251 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
252
253 /// Adds a new warning to the result.
254 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
255
256 /// Applies a function to the successful values of two results, merging warnings and errors.
257 fn merge<T, U, V>(
258 a: WarningResult<T, Self>,
259 b: WarningResult<U, Self>,
260 f: impl FnOnce(T, U) -> V,
261 ) -> WarningResult<V, Self>;
262}
263
264/// Warnings will return [`Error::UnhandledWarnings`].
265///
266/// See also: [`WarningHandling`]
267#[derive(Debug, Default, Clone)]
268pub struct WarningsAsErrors;
269
270impl WarningHandling for WarningsAsErrors {
271 type OkType<T> = T;
272
273 fn ok<T>(val: T) -> WarningResult<T, Self> {
274 Ok(val)
275 }
276
277 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
278 if let Ok(r) = result {
279 *r = value;
280 }
281 }
282
283 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
284 match result {
285 Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
286 Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
287 _ => {}
288 }
289 }
290
291 fn merge<T, U, V>(
292 a: WarningResult<T, Self>,
293 b: WarningResult<U, Self>,
294 f: impl FnOnce(T, U) -> V,
295 ) -> WarningResult<V, Self> {
296 Ok(f(a?, b?))
297 }
298}
299
300/// Warnings will be returned in a [`MaybeWarnings`].
301///
302/// See also: [`WarningHandling`]
303#[derive(Debug, Default, Clone)]
304pub struct ReturnWarnings;
305
306impl WarningHandling for ReturnWarnings {
307 type OkType<T> = MaybeWarnings<T>;
308
309 fn ok<T>(val: T) -> WarningResult<T, Self> {
310 Ok(MaybeWarnings {
311 value: val,
312 warnings: None,
313 })
314 }
315
316 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
317 if let Ok(r) = result {
318 r.value = value;
319 }
320 }
321
322 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
323 if let Ok(r) = result {
324 match r.warnings.as_mut() {
325 Some(warnings) => warnings.push(warning),
326 None => r.warnings = Some(vec![warning]),
327 }
328 }
329 }
330
331 fn merge<T, U, V>(
332 a: WarningResult<T, Self>,
333 b: WarningResult<U, Self>,
334 f: impl FnOnce(T, U) -> V,
335 ) -> WarningResult<V, Self> {
336 let MaybeWarnings {
337 value: a,
338 warnings: a_warnings,
339 } = a?;
340 let MaybeWarnings {
341 value: b,
342 warnings: b_warnings,
343 } = b?;
344 Ok(MaybeWarnings {
345 value: f(a, b),
346 warnings: a_warnings.or(b_warnings),
347 })
348 }
349}
350
351/// A result that may contain warnings.
352#[derive(Debug, Default, Clone)]
353pub struct MaybeWarnings<T> {
354 /// The successful result value.
355 pub value: T,
356
357 /// The list of warnings that occurred, if any.
358 pub warnings: Option<Vec<Warning>>,
359}
360
361impl<T> MaybeWarnings<T> {
362 /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
363 pub fn into_result(self) -> Result<T> {
364 match self.warnings {
365 Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
366 None => Ok(self.value),
367 }
368 }
369
370 /// Calls a function with a reference to the warnings, if any.
371 pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
372 if let Some(warnings) = &self.warnings {
373 f(warnings);
374 }
375 self
376 }
377}
378
379/// Warnings will be dropped.
380///
381/// See also: [`WarningHandling`]
382#[derive(Debug, Default, Clone)]
383pub struct IgnoreWarnings;
384
385impl WarningHandling for IgnoreWarnings {
386 type OkType<T> = T;
387
388 fn ok<T>(val: T) -> WarningResult<T, Self> {
389 Ok(val)
390 }
391
392 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
393 if let Ok(r) = result {
394 *r = value;
395 }
396 }
397
398 fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
399
400 fn merge<T, U, V>(
401 a: WarningResult<T, Self>,
402 b: WarningResult<U, Self>,
403 f: impl FnOnce(T, U) -> V,
404 ) -> WarningResult<V, Self> {
405 Ok(f(a?, b?))
406 }
407}
408
409/// Player colours.
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
411pub enum Player {
412 /// The black player.
413 #[serde(rename = "B")]
414 Black,
415 /// The white player.
416 #[serde(rename = "W")]
417 White,
418}
419
/// Converts an `sgf-parse` stone colour into a [`Player`].
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::Color> for Player {
    fn from(value: sgf_parse::Color) -> Self {
        use sgf_parse::Color;

        match value {
            Color::White => Self::White,
            Color::Black => Self::Black,
        }
    }
}
429
/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Coord(pub u8, pub u8);

impl Coord {
    /// Converts a coordinate from GTP format.
    ///
    /// A GTP vertex is one or more column letters (case-insensitive, with `I` skipped),
    /// followed by a 1-based row number counted from the bottom of the board. For example,
    /// `"A1"` is the bottom-left corner. Columns after `Z` use two letters (`"AA"` is
    /// column 25), matching [`Coord::to_gtp`].
    ///
    /// Returns `None` in any of these cases:
    /// - the string is not letters followed by decimal digits;
    /// - a column letter is `I`;
    /// - the row is outside `1..=height`;
    /// - the column does not fit in a `u8`.
    pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
        // Column letters must come first, followed by the row number.
        let split = s
            .find(|c: char| !c.is_ascii_alphabetic())
            .unwrap_or(s.len());
        let (letters, digits) = s.split_at(split);

        // Accumulate in u32 with checked arithmetic so over-long column names yield
        // `None` rather than overflowing (which panics in debug builds).
        let mut column: u32 = 0;
        for letter in letters.bytes().map(|b| b.to_ascii_uppercase()) {
            let value = match letter {
                b'I' => return None,
                b'A'..=b'H' => letter - b'A' + 1,
                // `I` is skipped, so J..=Z map to 9..=25.
                _ => letter - b'A',
            };
            column = column.checked_mul(25)?.checked_add(u32::from(value))?;
        }
        // No letters leaves `column == 0`, which `checked_sub` rejects.
        let x = column
            .checked_sub(1)
            .filter(|&x| x <= u32::from(u8::MAX))?;

        // `str::parse` would also accept a leading `+`, so require plain digits.
        if !digits.bytes().all(|b| b.is_ascii_digit()) {
            return None;
        }
        // Rows are 1-based; row 0 would map to `y == height`, which is off the board.
        let row = digits.parse::<u8>().ok().filter(|&row| row != 0)?;
        let y = height.checked_sub(row)?;

        // Lossless: `x` was bounded by `u8::MAX` above.
        Some(Self(x as u8, y))
    }

    /// Converts a coordinate to GTP format.
    ///
    /// This is the inverse of [`Coord::from_gtp`]. Columns use the letters `A`–`Z`
    /// without `I`, with a second letter for columns 25 and above. Rows are numbered
    /// from 1 at the bottom of a board with the given `height`.
    ///
    /// # Panics
    ///
    /// `y` should be less than `height`. `y == height` produces row `0`. `y > height`
    /// overflows the row subtraction, which panics when overflow checks are enabled
    /// (e.g. debug builds).
    pub fn to_gtp(self, height: u8) -> String {
        const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
        let Self(mut x, y) = self;
        // At most two column letters plus three row digits.
        let mut gtp = String::with_capacity(5);
        if x >= 25 {
            gtp.push(LETTERS[usize::from(x / 25) - 1] as char);
            x %= 25;
        }
        gtp.push(LETTERS[usize::from(x)] as char);
        gtp.push_str(&(height - y).to_string());
        gtp
    }
}
470
/// Converts an `sgf-parse` board point into a [`Coord`] with the same `x`/`y` values.
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Point> for Coord {
    fn from(c: sgf_parse::go::Point) -> Self {
        let sgf_parse::go::Point { x, y, .. } = c;
        Coord(x, y)
    }
}
477
478/// A move in a game.
479#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
480pub enum Move {
481 /// A move placing a stone at the specified coordinate.
482 Move(Coord),
483
484 /// A pass.
485 Pass,
486}
487
488impl Move {
489 /// Converts a move from GTP format.
490 pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
491 if s.eq_ignore_ascii_case("pass") {
492 Some(Self::Pass)
493 } else {
494 Coord::from_gtp(s, height).map(Self::Move)
495 }
496 }
497
498 /// Converts a move to GTP format.
499 pub fn to_gtp(self, height: u8) -> String {
500 match self {
501 Move::Move(coord) => coord.to_gtp(height),
502 Move::Pass => "pass".to_string(),
503 }
504 }
505}
506
/// Converts an `sgf-parse` move into a [`Move`].
#[cfg(feature = "sgf-parse")]
impl From<sgf_parse::go::Move> for Move {
    fn from(m: sgf_parse::go::Move) -> Self {
        use sgf_parse::go::Move as SgfMove;

        match m {
            SgfMove::Pass => Self::Pass,
            SgfMove::Move(point) => Self::Move(Coord::from(point)),
        }
    }
}
516
/// Version details of a KataGo build.
#[derive(Debug, Clone)]
pub struct VersionInfo {
    /// The latest KataGo release this build descends from, e.g. `"1.6.1"`.
    pub version: String,

    /// The exact git hash this build was compiled from, or `"<omitted>"` when KataGo was
    /// compiled separately from its repo or without Git support.
    pub git_hash: String,
}
528
529/// Information about a neural network model.
530#[derive(Debug, Clone, Deserialize)]
531#[serde(rename_all = "camelCase")]
532pub struct Model {
533 /// The model name.
534 pub name: String,
535
536 /// The internal name.
537 pub internal_name: String,
538
539 /// The maximum batch size.
540 pub max_batch_size: u32,
541
542 /// Whether it uses a humanSL profile.
543 #[serde(rename = "usesHumanSLProfile")]
544 pub uses_humansl_profile: bool,
545
546 /// The model version.
547 pub version: u32,
548
549 /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
550 /// it will be enabled if the backend deems it to be beneficial.
551 #[serde(rename = "usingFP16")]
552 pub using_fp16: Enabled,
553}
554
555/// The enabled state of a feature.
556#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
557#[serde(rename_all = "lowercase")]
558pub enum Enabled {
559 /// The feature is disabled.
560 False,
561
562 /// The feature is enabled.
563 True,
564
565 /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
566 Auto,
567}
568
#[cfg(test)]
mod tests {
    use crate::Coord;

    /// Parsing GTP vertices, including rejected inputs and two-letter columns.
    #[test]
    fn coord_from_gtp() {
        let cases: [(&str, u8, Option<Coord>); 10] = [
            ("A1", 19, Some(Coord(0, 18))),
            ("T19", 19, Some(Coord(18, 0))),
            ("I9", 19, None),
            ("1", 19, None),
            ("A", 19, None),
            ("A20", 19, None),
            ("Z1", 255, Some(Coord(24, 254))),
            ("AA1", 255, Some(Coord(25, 254))),
            ("BB1", 255, Some(Coord(51, 254))),
            ("JJ1", 255, Some(Coord(233, 254))),
        ];
        for &(gtp, height, expected) in &cases {
            assert_eq!(Coord::from_gtp(gtp, height), expected);
        }
    }

    /// Formatting coordinates as GTP vertices, including two-letter columns.
    #[test]
    fn coord_to_gtp() {
        let cases: [(Coord, u8, &str); 6] = [
            (Coord(0, 18), 19, "A1"),
            (Coord(18, 0), 19, "T19"),
            (Coord(24, 254), 255, "Z1"),
            (Coord(25, 254), 255, "AA1"),
            (Coord(51, 254), 255, "BB1"),
            (Coord(233, 254), 255, "JJ1"),
        ];
        for &(coord, height, expected) in &cases {
            assert_eq!(coord.to_gtp(height), expected);
        }
    }
}