Skip to main content

claudius/types/
stop_reason.rs

1use serde::{Deserialize, Serialize};
2use std::fmt;
3use std::str::FromStr;
4
5/// Reasons why the model stopped generating a response.
6#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7#[serde(rename_all = "snake_case")]
8pub enum StopReason {
9    /// The model reached the end of a generated turn
10    EndTurn,
11
12    /// The response reached the maximum token limit for the response
13    MaxTokens,
14
15    /// The model reached a specified stop sequence
16    StopSequence,
17
18    /// The model indicated it wants to use a tool
19    ToolUse,
20
21    /// The model paused in the middle of a turn
22    PauseTurn,
23
24    /// The model refused to respond due to safety or other considerations
25    Refusal,
26}
27
28impl fmt::Display for StopReason {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            StopReason::EndTurn => write!(f, "end_turn"),
32            StopReason::MaxTokens => write!(f, "max_tokens"),
33            StopReason::StopSequence => write!(f, "stop_sequence"),
34            StopReason::ToolUse => write!(f, "tool_use"),
35            StopReason::PauseTurn => write!(f, "pause_turn"),
36            StopReason::Refusal => write!(f, "refusal"),
37        }
38    }
39}
40
/// Error returned when a string does not name a known [`StopReason`].
///
/// Produced by `StopReason`'s `FromStr` implementation. It carries the exact
/// input that was rejected so callers can report or log it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StopReasonParseError {
    /// The input string that did not match any stop reason.
    pub invalid_value: String,
}

impl fmt::Display for StopReasonParseError {
    /// Renders as `Unknown stop reason: <input>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Unknown stop reason: {}", self.invalid_value)
    }
}

impl std::error::Error for StopReasonParseError {}
58
59impl FromStr for StopReason {
60    type Err = StopReasonParseError;
61
62    fn from_str(s: &str) -> Result<Self, Self::Err> {
63        match s {
64            "end_turn" => Ok(StopReason::EndTurn),
65            "max_tokens" => Ok(StopReason::MaxTokens),
66            "stop_sequence" => Ok(StopReason::StopSequence),
67            "tool_use" => Ok(StopReason::ToolUse),
68            "pause_turn" => Ok(StopReason::PauseTurn),
69            "refusal" => Ok(StopReason::Refusal),
70            _ => Err(StopReasonParseError {
71                invalid_value: s.to_string(),
72            }),
73        }
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected wire name. Keep this in sync
    /// with the enum so the table-driven tests below cover all variants.
    const ALL: [(StopReason, &str); 6] = [
        (StopReason::EndTurn, "end_turn"),
        (StopReason::MaxTokens, "max_tokens"),
        (StopReason::StopSequence, "stop_sequence"),
        (StopReason::ToolUse, "tool_use"),
        (StopReason::PauseTurn, "pause_turn"),
        (StopReason::Refusal, "refusal"),
    ];

    #[test]
    fn serialization() {
        for &(reason, name) in ALL.iter() {
            let json = serde_json::to_string(&reason).unwrap();
            assert_eq!(json, format!("\"{}\"", name));
        }
    }

    #[test]
    fn deserialization() {
        for &(reason, name) in ALL.iter() {
            let json = format!("\"{}\"", name);
            let parsed: StopReason = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, reason);
        }
    }

    #[test]
    fn deserialization_rejects_unknown() {
        assert!(serde_json::from_str::<StopReason>(r#""not_a_reason""#).is_err());
        assert!(serde_json::from_str::<StopReason>(r#""EndTurn""#).is_err());
    }

    #[test]
    fn display() {
        for &(reason, name) in ALL.iter() {
            assert_eq!(reason.to_string(), name);
        }
    }

    #[test]
    fn from_str_round_trip() {
        for &(reason, name) in ALL.iter() {
            assert_eq!(name.parse::<StopReason>().unwrap(), reason);
            // Display output must always be parseable back to the same variant.
            assert_eq!(reason.to_string().parse::<StopReason>().unwrap(), reason);
        }
    }

    #[test]
    fn from_str_rejects_unknown() {
        let err = "EndTurn".parse::<StopReason>().unwrap_err();
        assert_eq!(err.invalid_value, "EndTurn");
        assert_eq!(err.to_string(), "Unknown stop reason: EndTurn");

        let err = "".parse::<StopReason>().unwrap_err();
        assert_eq!(err.invalid_value, "");
    }
}