prover_config/
lib.rs

1use std::{str::FromStr, time::Duration};
2
3use prover_utils::{from_env_or_default, with};
4use serde::{Deserialize, Serialize};
5use serde_with::serde_as;
6use url::Url;
7
/// Default URL of the SP1 proving cluster's gRPC service.
///
/// Parsed by `default_sp1_cluster_endpoint` and passed as the default value to
/// `from_env_or_default("SP1_CLUSTER_ENDPOINT", ..)`.
const DEFAULT_SP1_CLUSTER_ENDPOINT: &str = "https://rpc.production.succinct.xyz/";
10
11/// Type of the prover to be used for generation of the pessimistic proof
12#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
13#[serde(rename_all = "kebab-case")]
14pub enum ProverType {
15    NetworkProver(NetworkProverConfig),
16    CpuProver(CpuProverConfig),
17    GpuProver(GpuProverConfig),
18    MockProver(MockProverConfig),
19}
20
21impl Default for ProverType {
22    fn default() -> Self {
23        ProverType::NetworkProver(NetworkProverConfig::default())
24    }
25}
26
27#[serde_as]
28#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
29#[serde(rename_all = "kebab-case")]
30pub struct CpuProverConfig {
31    #[serde(default = "default_max_concurrency_limit")]
32    pub max_concurrency_limit: usize,
33
34    #[serde_as(as = "Option<crate::with::HumanDuration>")]
35    pub proving_request_timeout: Option<Duration>,
36
37    #[serde(default = "default_local_proving_timeout")]
38    #[serde(with = "crate::with::HumanDuration")]
39    pub proving_timeout: Duration,
40}
41
42impl CpuProverConfig {
43    // This constant represents the number of second added to the proving_timeout
44    pub const DEFAULT_PROVING_TIMEOUT_PADDING: Duration = Duration::from_secs(1);
45
46    pub fn get_proving_request_timeout(&self) -> Duration {
47        self.proving_request_timeout
48            .unwrap_or_else(|| self.proving_timeout + Self::DEFAULT_PROVING_TIMEOUT_PADDING)
49    }
50}
51
52impl Default for CpuProverConfig {
53    fn default() -> Self {
54        Self {
55            max_concurrency_limit: default_max_concurrency_limit(),
56            proving_request_timeout: None,
57            proving_timeout: default_local_proving_timeout(),
58        }
59    }
60}
61
62#[serde_as]
63#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
64#[serde(rename_all = "kebab-case")]
65pub struct NetworkProverConfig {
66    #[serde_as(as = "Option<crate::with::HumanDuration>")]
67    pub proving_request_timeout: Option<Duration>,
68
69    #[serde(default = "default_network_proving_timeout")]
70    #[serde(with = "crate::with::HumanDuration")]
71    pub proving_timeout: Duration,
72
73    /// The sp1 proving cluster endpoint.
74    #[serde(default = "default_sp1_cluster_endpoint")]
75    pub sp1_cluster_endpoint: url::Url,
76}
77
78impl NetworkProverConfig {
79    // This constant represents the number of second added to the proving_timeout
80    pub const DEFAULT_PROVING_TIMEOUT_PADDING: Duration = Duration::from_secs(1);
81
82    pub fn get_proving_request_timeout(&self) -> Duration {
83        self.proving_request_timeout
84            .unwrap_or_else(|| self.proving_timeout + Self::DEFAULT_PROVING_TIMEOUT_PADDING)
85    }
86}
87
88impl Default for NetworkProverConfig {
89    fn default() -> Self {
90        Self {
91            proving_request_timeout: None,
92            proving_timeout: default_network_proving_timeout(),
93            sp1_cluster_endpoint: default_sp1_cluster_endpoint(),
94        }
95    }
96}
97
98#[serde_as]
99#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
100#[serde(rename_all = "kebab-case")]
101pub struct GpuProverConfig {
102    #[serde(default = "default_max_concurrency_limit")]
103    pub max_concurrency_limit: usize,
104
105    #[serde_as(as = "Option<crate::with::HumanDuration>")]
106    pub proving_request_timeout: Option<Duration>,
107
108    #[serde(default = "default_local_proving_timeout")]
109    #[serde(with = "crate::with::HumanDuration")]
110    pub proving_timeout: Duration,
111}
112
113impl GpuProverConfig {
114    // This constant represents the number of second added to the proving_timeout
115    pub const DEFAULT_PROVING_TIMEOUT_PADDING: Duration = Duration::from_secs(1);
116
117    pub fn get_proving_request_timeout(&self) -> Duration {
118        self.proving_request_timeout
119            .unwrap_or_else(|| self.proving_timeout + Self::DEFAULT_PROVING_TIMEOUT_PADDING)
120    }
121}
122
123impl Default for GpuProverConfig {
124    fn default() -> Self {
125        Self {
126            max_concurrency_limit: default_max_concurrency_limit(),
127            proving_request_timeout: None,
128            proving_timeout: default_local_proving_timeout(),
129        }
130    }
131}
132
133#[serde_as]
134#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
135#[serde(rename_all = "kebab-case")]
136pub struct MockProverConfig {
137    #[serde(default = "default_max_concurrency_limit")]
138    pub max_concurrency_limit: usize,
139
140    #[serde_as(as = "Option<crate::with::HumanDuration>")]
141    pub proving_request_timeout: Option<Duration>,
142
143    #[serde(default = "default_local_proving_timeout")]
144    #[serde(with = "crate::with::HumanDuration")]
145    pub proving_timeout: Duration,
146}
147
148impl MockProverConfig {
149    // This constant represents the number of second added to the proving_timeout
150    pub const DEFAULT_PROVING_TIMEOUT_PADDING: Duration = Duration::from_secs(1);
151
152    pub fn get_proving_request_timeout(&self) -> Duration {
153        self.proving_request_timeout
154            .unwrap_or_else(|| self.proving_timeout + Self::DEFAULT_PROVING_TIMEOUT_PADDING)
155    }
156}
157
158impl Default for MockProverConfig {
159    fn default() -> Self {
160        Self {
161            max_concurrency_limit: default_max_concurrency_limit(),
162            proving_request_timeout: None,
163            proving_timeout: default_local_proving_timeout(),
164        }
165    }
166}
167
/// Default `max_concurrency_limit` shared by the CPU, GPU and mock prover configs.
pub const fn default_max_concurrency_limit() -> usize {
    const LIMIT: usize = 100;
    LIMIT
}
171
/// Default `proving_timeout` for the CPU, GPU and mock prover configs: five minutes.
const fn default_local_proving_timeout() -> Duration {
    const FIVE_MINUTES_IN_SECS: u64 = 5 * 60;
    Duration::from_secs(FIVE_MINUTES_IN_SECS)
}
175
/// Default `proving_timeout` for the network prover config: five minutes.
const fn default_network_proving_timeout() -> Duration {
    const FIVE_MINUTES_IN_SECS: u64 = 5 * 60;
    Duration::from_secs(FIVE_MINUTES_IN_SECS)
}
179
180fn default_sp1_cluster_endpoint() -> Url {
181    from_env_or_default(
182        "SP1_CLUSTER_ENDPOINT",
183        Url::from_str(DEFAULT_SP1_CLUSTER_ENDPOINT).unwrap(),
184    )
185}