1use std::collections::{BTreeMap, HashMap};
2
3use solana_sdk::{account::Account, pubkey::Pubkey, transaction::Transaction};
4use thiserror::Error;
5use uuid::Uuid;
6use zksvm_api_types::{
7 api::{
8 EndSessionRequest, PollSessionResponse, SendTransactionRequest, SendTransactionResponse,
9 StartSessionRequest, StartSessionResponse,
10 },
11 common::{AccountProxy, ProofType, PubkeyWrapper},
12 errors::ApiResponse,
13};
14
15use crate::http_client::HttpClient;
16
17#[derive(Debug, Error)]
18pub enum SVMClientError {
19 #[error("Invalid uuid: {0}")]
20 InvalidUuid(#[from] uuid::Error),
21
22 #[error("Failed to reach server: {0}")]
23 HttpClientError(#[from] reqwest::Error),
24
25 #[error("Failed to start session: {0}")]
26 StartSessionError(String),
27
28 #[error("Session ID is required but missing")]
29 MissingSessionId,
30
31 #[error("Failed to send transaction: {0}")]
32 SendTransactionError(String),
33
34 #[error("Failed to serialize transaction: {0}")]
35 SerializationError(#[from] bincode::Error),
36
37 #[error("Failed to poll session status: {0}")]
38 PollSessionError(String),
39}
40
41pub struct Session {
42 client: HttpClient,
43 api_key: Uuid,
44 session_id: Option<Uuid>,
45}
46
47impl Session {
48 pub fn new(server_url: &str, server_port: &str, api_key: &str) -> Result<Self, SVMClientError> {
50 let api_key = Uuid::parse_str(api_key)?;
51 Ok(Self {
52 client: HttpClient::new(&format!("{server_url}:{server_port}")),
53 api_key,
54 session_id: None,
55 })
56 }
57
58 pub fn existing(
60 server_url: &str,
61 server_port: &str,
62 api_key: &str,
63 session_id: &str,
64 ) -> Result<Self, SVMClientError> {
65 let session_id = Some(Uuid::parse_str(session_id)?);
66 Ok(Self {
67 session_id,
68 ..Self::new(server_url, server_port, api_key)?
69 })
70 }
71
72 pub async fn start(
74 &mut self,
75 genesis_accounts: HashMap<Pubkey, Account>,
76 ) -> Result<Uuid, SVMClientError> {
77 let accounts: BTreeMap<PubkeyWrapper, AccountProxy> = genesis_accounts
78 .iter()
79 .map(|(pubkey, acc)| (pubkey.into(), acc.into()))
80 .collect();
81
82 let req = Some(StartSessionRequest {
83 genesis_accounts: accounts,
84 });
85 let res: ApiResponse<StartSessionResponse> =
86 self.client.post("/session", req, self.api_key).await?;
87
88 match res {
89 ApiResponse::Success {
90 data: StartSessionResponse { session_id },
91 } => {
92 self.session_id = Some(session_id);
93 Ok(session_id)
94 }
95 ApiResponse::Error { message, .. } => Err(SVMClientError::StartSessionError(message)),
96 }
97 }
98
99 pub async fn send_transaction(&self, tx: Transaction) -> Result<u8, SVMClientError> {
101 let session_id = self.session_id.ok_or(SVMClientError::MissingSessionId)?;
102
103 let req = Some(SendTransactionRequest {
104 transaction: bincode::serialize(&tx)?,
105 });
106 let res: ApiResponse<SendTransactionResponse> = self
107 .client
108 .post(
109 &format!("/session/{session_id}/transactions"),
110 req,
111 self.api_key,
112 )
113 .await?;
114
115 match res {
116 ApiResponse::Success {
117 data:
118 SendTransactionResponse {
119 remaining_transactions_in_session,
120 },
121 } => Ok(remaining_transactions_in_session),
122 ApiResponse::Error { message, .. } => {
123 Err(SVMClientError::SendTransactionError(message))
124 }
125 }
126 }
127
128 pub async fn end(&self, proof_type: ProofType) -> Result<(), SVMClientError> {
130 let session_id = self.session_id.ok_or(SVMClientError::MissingSessionId)?;
131
132 let req = Some(EndSessionRequest { proof_type });
133 self.client
134 .put::<_, ApiResponse<()>>(&format!("/session/{session_id}"), req, self.api_key)
135 .await?;
136
137 Ok(())
138 }
139
140 pub async fn poll_status(&self) -> Result<PollSessionResponse, SVMClientError> {
142 let session_id = self.session_id.ok_or(SVMClientError::MissingSessionId)?;
143
144 let res: ApiResponse<PollSessionResponse> = self
145 .client
146 .get(&format!("/session/{session_id}"), self.api_key)
147 .await?;
148
149 match res {
150 ApiResponse::Success { data } => Ok(data),
151 ApiResponse::Error { message, .. } => Err(SVMClientError::PollSessionError(message)),
152 }
153 }
154}