// katago_analysis/engine/request.rs

1use std::ops::Not;
2
3use serde::Serialize;
4use serde_json::{Value, json};
5use serde_with::skip_serializing_none;
6
7use crate::{Bonus, Config, Player, Rules};
8
9/// A request to the analysis engine.
10#[derive(Debug, Clone, Serialize)]
11#[serde(into = "Value")]
12#[expect(
13    clippy::large_enum_variant,
14    reason = "Boxing AnalysisRequest would be inconvenient, and very little would be gained"
15)]
16pub enum Request {
17    /// Request the engine to analyze one or more positions.
18    Analyze(AnalysisRequest),
19
20    /// Request KataGo's version information.
21    QueryVersion {
22        /// The request ID.
23        id: String,
24    },
25
26    /// Clear the neural network cache.
27    ClearCache {
28        /// The request ID.
29        id: String,
30    },
31
32    /// Terminate a specific analysis request.
33    Terminate {
34        /// The request ID.
35        id: String,
36
37        /// The ID of the request to terminate.
38        terminate_id: String,
39
40        /// If provided, only terminate the analysis for the specified turn numbers.
41        turn_numbers: Option<Vec<usize>>,
42    },
43
44    /// Terminate all pending analysis requests.
45    TerminateAll {
46        /// The request ID.
47        id: String,
48
49        /// If provided, only terminate the analysis for the specified turn numbers.
50        turn_numbers: Option<Vec<usize>>,
51    },
52
53    /// Request information about the available neural network models.
54    QueryModels {
55        /// The request ID.
56        id: String,
57    },
58}
59
60impl From<Request> for Value {
61    fn from(request: Request) -> Self {
62        match request {
63            Request::Analyze(request) => {
64                serde_json::to_value(request).expect("request should be serializable")
65            }
66            Request::QueryVersion { id } => json!({
67                "id": id,
68                "action": "query_version",
69            }),
70            Request::ClearCache { id } => json!({
71                "id": id,
72                "action": "clear_cache",
73            }),
74            Request::Terminate {
75                id,
76                terminate_id,
77                turn_numbers,
78            } => {
79                let mut value = json!({
80                        "id": id,
81                        "action": "terminate",
82                        "terminateId": terminate_id,
83                    }
84                );
85                if let Some(turn_numbers) = turn_numbers {
86                    value
87                        .as_object_mut()
88                        .expect("value should be an object")
89                        .insert("turnNumbers".to_string(), json!(turn_numbers));
90                }
91                value
92            }
93            Request::TerminateAll { id, turn_numbers } => {
94                let mut value = json!({
95                        "id": id,
96                        "action": "terminate_all",
97                    }
98                );
99                if let Some(turn_numbers) = turn_numbers {
100                    value
101                        .as_object_mut()
102                        .expect("value should be an object")
103                        .insert("turnNumbers".to_string(), json!(turn_numbers));
104                }
105                value
106            }
107            Request::QueryModels { id } => json!({
108                "id": id,
109                "action": "query_models",
110            }),
111        }
112    }
113}
114
115/// A game record to be analyzed, along with analysis settings.
116#[skip_serializing_none]
117#[derive(Debug, Clone, Serialize)]
118#[serde(rename_all = "camelCase")]
119pub struct AnalysisRequest {
120    /// The request ID.
121    pub id: String,
122
123    /// The ruleset for this game.
124    pub rules: Rules,
125
126    /// The komi for this game.
127    pub komi: Option<f64>,
128
129    /// Bonus points white receives in handicap games.
130    pub white_handicap_bonus: Option<Bonus>,
131
132    /// The board width.
133    pub board_x_size: u8,
134
135    /// The board height.
136    pub board_y_size: u8,
137
138    /// The stones on the board before the first move.
139    pub initial_stones: Option<Vec<(Player, String)>>,
140
141    /// The player to move in the initial position.
142    pub initial_player: Option<Player>,
143
144    /// The moves played in the game. Move locations can be in GTP format (`"A1"`, `"pass"`, etc.) or explicit
145    /// coordinates (`"(0,0)"`).
146    pub moves: Vec<(Player, String)>,
147
148    /// The positions to analyze, where 0 is the position before the first move.
149    /// If not provided, only the final position will be analyzed.
150    /// The engine will return a separate response for each position.
151    pub analyze_turns: Option<Vec<usize>>,
152
153    /// The maximum number of visits to use.
154    pub max_visits: Option<u32>,
155
156    /// Root policy temperature.
157    pub root_policy_temperature: Option<f64>,
158
159    /// Root FPU reduction max.
160    pub root_fpu_reduction_max: Option<f64>,
161
162    /// The maximum length of the principal variation to return, not including the first move.
163    #[serde(rename = "analysisPVLen")]
164    pub analysis_pv_len: Option<usize>,
165
166    /// Whether to return the ownership prediction.
167    #[serde(skip_serializing_if = "Not::not")]
168    pub include_ownership: bool,
169
170    /// Whether to return the standard deviation of the ownership prediction.
171    #[serde(skip_serializing_if = "Not::not")]
172    pub include_ownership_stdev: bool,
173
174    /// Whether to return the ownership prediction for each move.
175    #[serde(skip_serializing_if = "Not::not")]
176    pub include_moves_ownership: bool,
177
178    /// Whether to return the standard deviation of the ownership prediction for each move.
179    #[serde(skip_serializing_if = "Not::not")]
180    pub include_moves_ownership_stdev: bool,
181
182    /// Whether to return the neural network policy output.
183    #[serde(skip_serializing_if = "Not::not")]
184    pub include_policy: bool,
185
186    /// Whether to return the number of visits for each position in the principal variation.
187    #[serde(rename = "includePVVisits", skip_serializing_if = "Not::not")]
188    pub include_pv_visits: bool,
189
190    /// Whether to return the predicted probability that the game will have a void result.
191    #[serde(skip_serializing_if = "Not::not")]
192    pub include_no_result_value: bool,
193
194    /// Moves which are forbidden.
195    pub avoid_moves: Option<Vec<RestrictedMoves>>,
196
197    /// Moves which are allowed. If specified, all other moves are forbidden.
198    pub allow_moves: Option<Vec<RestrictedMoves>>,
199
200    /// Config overrides for this request.
201    pub override_settings: Option<Config>,
202
203    /// Report partial analysis results every this many seconds.
204    pub report_during_search_every: Option<f64>,
205
206    /// The priority of this request.
207    pub priority: Option<i32>,
208
209    /// The priorities of each position to analyze.
210    pub priorities: Option<Vec<i32>>,
211}
212
213impl AnalysisRequest {
214    /// Creates a new analysis request with the minimum required parameters.
215    pub fn new(
216        id: String,
217        rules: Rules,
218        board_x_size: u8,
219        board_y_size: u8,
220        moves: Vec<(Player, String)>,
221    ) -> Self {
222        Self {
223            id,
224            rules,
225            komi: None,
226            white_handicap_bonus: None,
227            board_x_size,
228            board_y_size,
229            initial_stones: None,
230            initial_player: None,
231            moves,
232            analyze_turns: None,
233            max_visits: None,
234            root_policy_temperature: None,
235            root_fpu_reduction_max: None,
236            analysis_pv_len: None,
237            include_ownership: false,
238            include_ownership_stdev: false,
239            include_moves_ownership: false,
240            include_moves_ownership_stdev: false,
241            include_policy: false,
242            include_pv_visits: false,
243            include_no_result_value: false,
244            avoid_moves: None,
245            allow_moves: None,
246            override_settings: None,
247            report_during_search_every: None,
248            priority: None,
249            priorities: None,
250        }
251    }
252
253    /// Sets komi.
254    pub fn with_komi(mut self, komi: f64) -> Self {
255        self.komi = Some(komi);
256        self
257    }
258
259    /// Sets white's handicap bonus.
260    pub fn with_white_handicap_bonus(mut self, bonus: Bonus) -> Self {
261        self.white_handicap_bonus = Some(bonus);
262        self
263    }
264
265    /// Sets the initial position before the first move.
266    pub fn with_initial_stones(mut self, initial_stones: Vec<(Player, String)>) -> Self {
267        self.initial_stones = Some(initial_stones);
268        self
269    }
270
271    /// Sets the player to move in the initial position.
272    pub fn with_initial_player(mut self, initial_player: Player) -> Self {
273        self.initial_player = Some(initial_player);
274        self
275    }
276
277    /// Analyzes the specified positions. The position before the first move is turn 0.
278    pub fn with_analyze_turns(mut self, analyze_turns: Vec<usize>) -> Self {
279        self.analyze_turns = Some(analyze_turns);
280        self
281    }
282
283    /// Sets the maximum number of visits to use.
284    pub fn with_max_visits(mut self, max_visits: u32) -> Self {
285        self.max_visits = Some(max_visits);
286        self
287    }
288
289    /// Sets the root policy temperature.
290    pub fn with_root_policy_temperature(mut self, root_policy_temperature: f64) -> Self {
291        self.root_policy_temperature = Some(root_policy_temperature);
292        self
293    }
294
295    /// Sets the root FPU reduction max.
296    pub fn with_root_fpu_reduction_max(mut self, root_fpu_reduction_max: f64) -> Self {
297        self.root_fpu_reduction_max = Some(root_fpu_reduction_max);
298        self
299    }
300
301    /// Sets the maximum length of the principal variation to return, not including the first move.
302    pub fn with_analysis_pv_len(mut self, analysis_pv_len: usize) -> Self {
303        self.analysis_pv_len = Some(analysis_pv_len);
304        self
305    }
306
307    /// Includes the ownership prediction.
308    pub fn with_ownership(mut self) -> Self {
309        self.include_ownership = true;
310        self
311    }
312
313    /// Includes the standard deviation of the ownership prediction.
314    pub fn with_ownership_stdev(mut self) -> Self {
315        self.include_ownership_stdev = true;
316        self
317    }
318
319    /// Includes the ownership prediction for each move.
320    pub fn with_moves_ownership(mut self) -> Self {
321        self.include_moves_ownership = true;
322        self
323    }
324
325    /// Includes the standard deviation of the ownership prediction for each move.
326    pub fn with_moves_ownership_stdev(mut self) -> Self {
327        self.include_moves_ownership_stdev = true;
328        self
329    }
330
331    /// Includes the neural network policy output.
332    pub fn with_policy(mut self) -> Self {
333        self.include_policy = true;
334        self
335    }
336
337    /// Includes the number of visits for each position in the principal variation.
338    pub fn with_pv_visits(mut self) -> Self {
339        self.include_pv_visits = true;
340        self
341    }
342
343    /// Includes the predicted probability that the game will have a void result.
344    pub fn with_no_result_value(mut self) -> Self {
345        self.include_no_result_value = true;
346        self
347    }
348
349    /// Sets moves which are forbidden.
350    pub fn with_avoid_moves(mut self, avoid_moves: Vec<RestrictedMoves>) -> Self {
351        self.avoid_moves = Some(avoid_moves);
352        self
353    }
354
355    /// Sets moves which are allowed.
356    pub fn with_allow_moves(mut self, allow_moves: Vec<RestrictedMoves>) -> Self {
357        self.allow_moves = Some(allow_moves);
358        self
359    }
360
361    /// Overrides config settings for this request.
362    pub fn with_override_settings(mut self, config: Config) -> Self {
363        self.override_settings = Some(config);
364        self
365    }
366
367    /// Gets partial analysis results every this many seconds.
368    pub fn with_report_during_search_every(mut self, seconds: f64) -> Self {
369        self.report_during_search_every = Some(seconds);
370        self
371    }
372
373    /// Sets the priority of this request.
374    pub fn with_priority(mut self, priority: i32) -> Self {
375        self.priority = Some(priority);
376        self
377    }
378
379    /// Sets the priorities of each position to analyze.
380    pub fn with_priorities(mut self, priorities: Vec<i32>) -> Self {
381        self.priorities = Some(priorities);
382        self
383    }
384}
385
386/// A list of moves that are either forbidden with [`AnalysisRequest::avoid_moves`] or allowed with
387/// [`AnalysisRequest::allow_moves`].
388#[derive(Debug, Clone, Serialize)]
389#[serde(rename_all = "camelCase")]
390pub struct RestrictedMoves {
391    /// The player the move restriction applies to.
392    pub player: Player,
393
394    /// The list of moves.
395    pub moves: Vec<String>,
396
397    /// The search depth within which the restriction applies.
398    pub until_depth: u32,
399}