katago_analysis/lib.rs
//! A wrapper around KataGo's analysis protocol.
//!
//! See [KataGo Parallel Analysis Engine](https://github.com/lightvector/KataGo/blob/master/docs/Analysis_Engine.md)
//! for official documentation of the analysis engine.
//!
//! Note: The asynchronous methods in this library must be called from within a Tokio runtime.
//!
//! # Examples
//!
//! After launching an [`Engine`](engine::Engine), the primary entry point for using this library is [`Analyzer`].
//!
//! ```
//! use katago_analysis::{
//!     AnalysisRequest, Analyzer, Coord, Move, Player, Result, Rules,
//!     engine::{Engine, LaunchOptions},
//! };
//!
//! async fn example(
//!     katago_path: String,
//!     analysis_config_path: String,
//!     model_path: String,
//! ) -> Result<()> {
//!     let options = LaunchOptions::new(katago_path, analysis_config_path, model_path);
//!     let mut analyzer: Analyzer = Engine::launch(&options)?.into();
//!
//!     let request = AnalysisRequest::new(
//!         Rules::chinese(),
//!         19,
//!         19,
//!         vec![
//!             (Player::Black, Move::Move(Coord(15, 3))),
//!             (Player::White, Move::Move(Coord(3, 15))),
//!         ],
//!     );
//!
//!     let results = analyzer.analyze_game(request).await?;
//!     for i in 0..results.len() {
//!         println!(
//!             "Move {i}: {:.1}%",
//!             results.get(&i).unwrap().root_info.winrate * 100.0
//!         );
//!     }
//!     Ok(())
//! }
//! ```
46
47#![warn(missing_docs)]
48
49use std::{io, sync::Arc};
50
51use serde::{Deserialize, Serialize};
52use thiserror::Error;
53
54pub mod engine;
55
56mod analyzer;
57pub use analyzer::*;
58
59mod config;
60pub use config::*;
61
62mod request;
63pub use request::*;
64
65mod result;
66pub use result::*;
67
68mod rules;
69pub use rules::*;
70
71/// The type of results returned by methods in this library.
72pub type Result<T> = std::result::Result<T, Error>;
73
74/// The type of results which may contain warnings returned by the analysis engine.
75///
76/// See also: [`WarningHandling`]
77pub type WarningResult<T, W = WarningsAsErrors> = Result<<W as WarningHandling>::OkType<T>>;
78
79/// Errors that can occur while interacting with the analysis engine.
80#[derive(Debug, Clone, Error)]
81pub enum Error {
82 /// An I/O error occurred while launching, writing to, or reading from the analysis engine.
83 #[error("I/O error: {0}")]
84 Io(#[from] Arc<io::Error>),
85
86 /// An error occurred while serializing or deserializing a message.
87 #[error("serialization error: {0}")]
88 Serialization(#[from] Arc<serde_json::Error>),
89
90 /// The engine's stdin was unavailable after launch.
91 #[error("stdin unavailable")]
92 StdinUnavailable,
93
94 /// The engine's stdout was unavailable after launch.
95 #[error("stdout unavailable")]
96 StdoutUnavailable,
97
98 /// The specified config setting's value could not be serialized.
99 #[error("unserializable config value for key {0}")]
100 UnserializableConfig(String),
101
102 /// The analysis engine returned an error response without specifying which request caused it.
103 ///
104 /// When this error occurs, all pending requests will return this error, even if they might have otherwise
105 /// succeeded.
106 #[error("invalid request: {error}")]
107 KataGoGeneralError {
108 /// The error message provided by KataGo.
109 error: String,
110 },
111
112 /// The analysis engine returned an error response.
113 ///
114 /// When this error occurs, all positions still being analyzed as part of the associated request will return this
115 /// error, even if they might have otherwise succeeded.
116 #[error("invalid field {field}: {error}")]
117 KataGoFieldError {
118 /// The error message provided by KataGo.
119 error: String,
120
121 /// The request field that caused the error.
122 field: String,
123 },
124
125 /// The analysis engine returned a warning response which was converted to an error.
126 ///
127 /// See also: [`WarningHandling`]
128 #[error("unhandled warnings: {0:?}")]
129 UnhandledWarnings(Vec<Warning>),
130}
131
132impl From<io::Error> for Error {
133 fn from(e: io::Error) -> Self {
134 Error::Io(Arc::new(e))
135 }
136}
137
138impl From<serde_json::Error> for Error {
139 fn from(e: serde_json::Error) -> Self {
140 Error::Serialization(Arc::new(e))
141 }
142}
143
144/// A warning returned by the analysis engine.
145#[derive(Debug, Clone)]
146pub struct Warning {
147 /// The warning message provided by KataGo.
148 pub warning: String,
149
150 /// The request field that caused the warning.
151 pub field: String,
152}
153
154/// Specifies how warnings from the analysis engine should be handled.
155///
156/// # Warning Handling
157///
158/// [`Analyzer`] can handle analysis engine warnings in several ways. The `Ok` result type of methods that can produce
159/// warnings will vary depending on which strategy is chosen.
160///
161/// - [`Analyzer<WarningsAsErrors>`](WarningsAsErrors) will return the successful result directly, or return [`Error::UnhandledWarnings`]
162/// when a warning occurs. This is the default strategy.
163/// - [`Analyzer<ReturnWarnings>`](ReturnWarnings) will return a result wrapped in a [`MaybeWarnings`].
164/// - [`Analyzer<IgnoreWarnings>`](IgnoreWarnings) will return the successful result directly, and ignore any warnings that occur.
165/// This is not recommended unless you expect warnings to occur and don't intend to handle them.
166///
167/// # Examples
168///
169/// ### [`WarningsAsErrors`]
170///
171/// ```
172/// # use katago_analysis::*;
173/// async fn example(
174/// analyzer: &mut Analyzer,
175/// request: AnalysisRequest
176/// ) -> Result<()> {
177/// let result: AnalysisResult = analyzer
178/// .analyze(request)
179/// .await? // Warnings will cause this to fail with an error
180/// .unwrap();
181/// println!("{:.1}%", result.root_info.winrate * 100.0);
182/// Ok(())
183/// }
184/// ```
185///
186/// ### [`ReturnWarnings`]
187///
188/// ```
189/// # use katago_analysis::*;
190/// async fn example(
191/// analyzer: &mut Analyzer<ReturnWarnings>,
192/// request: AnalysisRequest,
193/// ) -> Result<()> {
194/// let result: AnalysisResult = analyzer
195/// .analyze(request)
196/// .await?
197/// .inspect_warnings(|warnings| {
198/// for warning in warnings {
199/// println!("{warning:?}");
200/// }
201/// })
202/// .value
203/// .unwrap();
204/// println!("{:.1}%", result.root_info.winrate * 100.0);
205/// Ok(())
206/// }
207/// ```
208///
209/// Warnings can also be converted back to errors:
210///
211/// ```
212/// # use katago_analysis::*;
213/// async fn example(
214/// analyzer: &mut Analyzer<ReturnWarnings>,
215/// request: AnalysisRequest,
216/// ) -> Result<()> {
217/// let result: AnalysisResult = analyzer
218/// .analyze(request)
219/// .await?
220/// .into_result()? // Warnings will cause this to fail with an error
221/// .unwrap();
222/// println!("{:.1}%", result.root_info.winrate * 100.0);
223/// Ok(())
224/// }
225/// ```
226///
227/// ### [`IgnoreWarnings`]
228///
229/// ```
230/// # use katago_analysis::*;
231/// async fn example(
232/// analyzer: &mut Analyzer<IgnoreWarnings>,
233/// request: AnalysisRequest
234/// ) -> Result<()> {
235/// let result: AnalysisResult = analyzer
236/// .analyze(request)
237/// .await? // Warnings will be ignored
238/// .unwrap();
239/// println!("{:.1}%", result.root_info.winrate * 100.0);
240/// Ok(())
241/// }
242/// ```
243pub trait WarningHandling {
244 /// The `Ok` result type.
245 type OkType<T>;
246
247 /// Creates a successful result containing the given value and no warnings.
248 fn ok<T>(val: T) -> WarningResult<T, Self>;
249
250 /// Updates the successful result, preserving errors and warnings.
251 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T);
252
253 /// Adds a new warning to the result.
254 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning);
255
256 /// Applies a function to the successful values of two results, merging warnings and errors.
257 fn merge<T, U, V>(
258 a: WarningResult<T, Self>,
259 b: WarningResult<U, Self>,
260 f: impl FnOnce(T, U) -> V,
261 ) -> WarningResult<V, Self>;
262}
263
264/// Warnings will return [`Error::UnhandledWarnings`].
265///
266/// See also: [`WarningHandling`]
267#[derive(Debug, Default, Clone)]
268pub struct WarningsAsErrors;
269
270impl WarningHandling for WarningsAsErrors {
271 type OkType<T> = T;
272
273 fn ok<T>(val: T) -> WarningResult<T, Self> {
274 Ok(val)
275 }
276
277 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
278 if let Ok(r) = result {
279 *r = value;
280 }
281 }
282
283 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
284 match result {
285 Ok(_) => *result = Err(Error::UnhandledWarnings(vec![warning])),
286 Err(Error::UnhandledWarnings(warnings)) => warnings.push(warning),
287 _ => {}
288 }
289 }
290
291 fn merge<T, U, V>(
292 a: WarningResult<T, Self>,
293 b: WarningResult<U, Self>,
294 f: impl FnOnce(T, U) -> V,
295 ) -> WarningResult<V, Self> {
296 Ok(f(a?, b?))
297 }
298}
299
300/// Warnings will be returned in a [`MaybeWarnings`].
301///
302/// See also: [`WarningHandling`]
303#[derive(Debug, Default, Clone)]
304pub struct ReturnWarnings;
305
306impl WarningHandling for ReturnWarnings {
307 type OkType<T> = MaybeWarnings<T>;
308
309 fn ok<T>(val: T) -> WarningResult<T, Self> {
310 Ok(MaybeWarnings {
311 value: val,
312 warnings: None,
313 })
314 }
315
316 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
317 if let Ok(r) = result {
318 r.value = value;
319 }
320 }
321
322 fn add_warning<T>(result: &mut WarningResult<T, Self>, warning: Warning) {
323 if let Ok(r) = result {
324 match r.warnings.as_mut() {
325 Some(warnings) => warnings.push(warning),
326 None => r.warnings = Some(vec![warning]),
327 }
328 }
329 }
330
331 fn merge<T, U, V>(
332 a: WarningResult<T, Self>,
333 b: WarningResult<U, Self>,
334 f: impl FnOnce(T, U) -> V,
335 ) -> WarningResult<V, Self> {
336 let MaybeWarnings {
337 value: a,
338 warnings: a_warnings,
339 } = a?;
340 let MaybeWarnings {
341 value: b,
342 warnings: b_warnings,
343 } = b?;
344 Ok(MaybeWarnings {
345 value: f(a, b),
346 warnings: a_warnings.or(b_warnings),
347 })
348 }
349}
350
351/// A result that may contain warnings.
352#[derive(Debug, Default, Clone)]
353pub struct MaybeWarnings<T> {
354 /// The successful result value.
355 pub value: T,
356
357 /// The list of warnings that occurred, if any.
358 pub warnings: Option<Vec<Warning>>,
359}
360
361impl<T> MaybeWarnings<T> {
362 /// Extracts the successful result value, returning [`Error::UnhandledWarnings`] if any warnings occurred.
363 pub fn into_result(self) -> Result<T> {
364 match self.warnings {
365 Some(warnings) => Err(Error::UnhandledWarnings(warnings)),
366 None => Ok(self.value),
367 }
368 }
369
370 /// Calls a function with a reference to the warnings, if any.
371 pub fn inspect_warnings<F: FnOnce(&Vec<Warning>)>(self, f: F) -> Self {
372 if let Some(warnings) = &self.warnings {
373 f(warnings);
374 }
375 self
376 }
377}
378
379/// Warnings will be dropped.
380///
381/// See also: [`WarningHandling`]
382#[derive(Debug, Default, Clone)]
383pub struct IgnoreWarnings;
384
385impl WarningHandling for IgnoreWarnings {
386 type OkType<T> = T;
387
388 fn ok<T>(val: T) -> WarningResult<T, Self> {
389 Ok(val)
390 }
391
392 fn set_result<T>(result: &mut WarningResult<T, Self>, value: T) {
393 if let Ok(r) = result {
394 *r = value;
395 }
396 }
397
398 fn add_warning<T>(_result: &mut WarningResult<T, Self>, _warning: Warning) {}
399
400 fn merge<T, U, V>(
401 a: WarningResult<T, Self>,
402 b: WarningResult<U, Self>,
403 f: impl FnOnce(T, U) -> V,
404 ) -> WarningResult<V, Self> {
405 Ok(f(a?, b?))
406 }
407}
408
409/// Player colours.
410#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
411pub enum Player {
412 /// The black player.
413 #[serde(rename = "B")]
414 Black,
415 /// The white player.
416 #[serde(rename = "W")]
417 White,
418}
419
420/// A board location in (x, y) format, where (0, 0) is the top-left corner of the board.
421#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
422pub struct Coord(pub u8, pub u8);
423
424impl Coord {
425 /// Converts a coordinate from GTP format.
426 pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
427 let (x_part, y_part) = s
428 .chars()
429 .partition::<String, _>(|c| c.is_ascii_alphabetic());
430
431 let mut x = 0;
432 for c in x_part.chars().map(|c| c.to_ascii_uppercase()) {
433 x *= 25;
434 x += (c as u8) - b'A' + 1;
435 if c == 'I' {
436 return None;
437 } else if c > 'I' {
438 x -= 1;
439 }
440 }
441 x = x.checked_sub(1)?;
442 let y = height.checked_sub(y_part.parse::<u8>().ok()?)?;
443 Some(Self(x, y))
444 }
445
446 /// Converts a coordinate to GTP format.
447 pub fn to_gtp(self, height: u8) -> String {
448 let Self(mut x, y) = self;
449 const LETTERS: &[u8; 25] = b"ABCDEFGHJKLMNOPQRSTUVWXYZ";
450 let mut gtp = String::with_capacity(4);
451 if x >= 25 {
452 gtp.push(LETTERS[(x / 25) as usize - 1] as char);
453 x %= 25;
454 }
455 gtp.push(LETTERS[x as usize] as char);
456 gtp.push_str(&(height - y).to_string());
457 gtp
458 }
459}
460
461/// A move in a game.
462#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
463pub enum Move {
464 /// A move placing a stone at the specified coordinate.
465 Move(Coord),
466
467 /// A pass.
468 Pass,
469}
470
471impl Move {
472 /// Converts a move from GTP format.
473 pub fn from_gtp(s: &str, height: u8) -> Option<Self> {
474 if s.to_ascii_lowercase() == "pass" {
475 Some(Self::Pass)
476 } else {
477 Coord::from_gtp(s, height).map(Self::Move)
478 }
479 }
480
481 /// Converts a move to GTP format.
482 pub fn to_gtp(self, height: u8) -> String {
483 match self {
484 Move::Move(coord) => coord.to_gtp(height),
485 Move::Pass => "pass".to_string(),
486 }
487 }
488}
489
490/// KataGo's version information.
491#[derive(Debug, Clone)]
492pub struct VersionInfo {
493 /// A string indicating the most recent KataGo release version that this version is a descendant of,
494 /// such as `"1.6.1"`.
495 pub version: String,
496
497 /// The precise git hash this KataGo version was compiled from, or the string `"<omitted>"` if KataGo was
498 /// compiled separately from its repo or without Git support.
499 pub git_hash: String,
500}
501
502/// Information about a neural network model.
503#[derive(Debug, Clone, Deserialize)]
504#[serde(rename_all = "camelCase")]
505pub struct Model {
506 /// The model name.
507 pub name: String,
508
509 /// The internal name.
510 pub internal_name: String,
511
512 /// The maximum batch size.
513 pub max_batch_size: u32,
514
515 /// Whether it uses a humanSL profile.
516 #[serde(rename = "usesHumanSLProfile")]
517 pub uses_humansl_profile: bool,
518
519 /// The model version.
520 pub version: u32,
521
522 /// Whether FP16 is used for this model. If this is [`Auto`][Enabled::Auto],
523 /// it will be enabled if the backend deems it to be beneficial.
524 #[serde(rename = "usingFP16")]
525 pub using_fp16: Enabled,
526}
527
528/// The enabled state of a feature.
529#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize)]
530#[serde(rename_all = "lowercase")]
531pub enum Enabled {
532 /// The feature is disabled.
533 False,
534
535 /// The feature is enabled.
536 True,
537
538 /// The feature will be automatically enabled or disabled based on what the engine thinks is best.
539 Auto,
540}
541
542#[cfg(test)]
543mod tests {
544 use crate::Coord;
545
546 #[test]
547 fn coord_from_gtp() {
548 assert_eq!(Coord::from_gtp("A1", 19), Some(Coord(0, 18)));
549 assert_eq!(Coord::from_gtp("T19", 19), Some(Coord(18, 0)));
550 assert_eq!(Coord::from_gtp("I9", 19), None);
551 assert_eq!(Coord::from_gtp("1", 19), None);
552 assert_eq!(Coord::from_gtp("A", 19), None);
553 assert_eq!(Coord::from_gtp("A20", 19), None);
554 assert_eq!(Coord::from_gtp("Z1", 255), Some(Coord(24, 254)));
555 assert_eq!(Coord::from_gtp("AA1", 255), Some(Coord(25, 254)));
556 assert_eq!(Coord::from_gtp("BB1", 255), Some(Coord(51, 254)));
557 assert_eq!(Coord::from_gtp("JJ1", 255), Some(Coord(233, 254)));
558 }
559
560 #[test]
561 fn coord_to_gtp() {
562 assert_eq!(Coord(0, 18).to_gtp(19), "A1");
563 assert_eq!(Coord(18, 0).to_gtp(19), "T19");
564 assert_eq!(Coord(24, 254).to_gtp(255), "Z1");
565 assert_eq!(Coord(25, 254).to_gtp(255), "AA1");
566 assert_eq!(Coord(51, 254).to_gtp(255), "BB1");
567 assert_eq!(Coord(233, 254).to_gtp(255), "JJ1");
568 }
569}