1use serde::Serialize;
12
13use crate::{
14 Client,
15 config::Config,
16 error::OpenAIError,
17 steps::Steps,
18 types::{
19 AssistantEventStream, CreateRunRequest, ListRunsResponse, ModifyRunRequest, RunObject,
20 SubmitToolOutputsRunRequest,
21 },
22};
23
24pub struct Runs<'c, C: Config> {
28 pub thread_id: String,
29 client: &'c Client<C>,
30}
31
32impl<'c, C: Config> Runs<'c, C> {
33 pub fn new(client: &'c Client<C>, thread_id: &str) -> Self {
34 Self {
35 client,
36 thread_id: thread_id.into(),
37 }
38 }
39
40 pub fn steps(&self, run_id: &str) -> Steps<C> {
42 Steps::new(self.client, &self.thread_id, run_id)
43 }
44
45 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
47 pub async fn create(&self, request: CreateRunRequest) -> Result<RunObject, OpenAIError> {
48 self.client
49 .post(&format!("/threads/{}/runs", self.thread_id), request)
50 .await
51 }
52
53 #[crate::byot(
57 T0 = serde::Serialize,
58 R = serde::de::DeserializeOwned,
59 stream = "true",
60 where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
61 )]
62 #[allow(unused_mut)]
63 pub async fn create_stream(
64 &self,
65 mut request: CreateRunRequest,
66 ) -> Result<AssistantEventStream, OpenAIError> {
67 #[cfg(not(feature = "byot"))]
68 {
69 if request.stream.is_some() && !request.stream.unwrap() {
70 return Err(OpenAIError::InvalidArgument(
71 "When stream is false, use Runs::create".into(),
72 ));
73 }
74
75 request.stream = Some(true);
76 }
77
78 Ok(self
79 .client
80 .post_stream_mapped_raw_events(
81 &format!("/threads/{}/runs", self.thread_id),
82 request,
83 TryFrom::try_from,
84 )
85 .await)
86 }
87
88 #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
90 pub async fn retrieve(&self, run_id: &str) -> Result<RunObject, OpenAIError> {
91 self.client
92 .get(&format!("/threads/{}/runs/{run_id}", self.thread_id))
93 .await
94 }
95
96 #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
98 pub async fn update(
99 &self,
100 run_id: &str,
101 request: ModifyRunRequest,
102 ) -> Result<RunObject, OpenAIError> {
103 self.client
104 .post(
105 &format!("/threads/{}/runs/{run_id}", self.thread_id),
106 request,
107 )
108 .await
109 }
110
111 #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)]
113 pub async fn list<Q>(&self, query: &Q) -> Result<ListRunsResponse, OpenAIError>
114 where
115 Q: Serialize + ?Sized,
116 {
117 self.client
118 .get_with_query(&format!("/threads/{}/runs", self.thread_id), &query)
119 .await
120 }
121
122 #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)]
124 pub async fn submit_tool_outputs(
125 &self,
126 run_id: &str,
127 request: SubmitToolOutputsRunRequest,
128 ) -> Result<RunObject, OpenAIError> {
129 self.client
130 .post(
131 &format!(
132 "/threads/{}/runs/{run_id}/submit_tool_outputs",
133 self.thread_id
134 ),
135 request,
136 )
137 .await
138 }
139
140 #[crate::byot(
142 T0 = std::fmt::Display,
143 T1 = serde::Serialize,
144 R = serde::de::DeserializeOwned,
145 stream = "true",
146 where_clause = "R: std::marker::Send + 'static + TryFrom<eventsource_stream::Event, Error = OpenAIError>"
147 )]
148 #[allow(unused_mut)]
149 pub async fn submit_tool_outputs_stream(
150 &self,
151 run_id: &str,
152 mut request: SubmitToolOutputsRunRequest,
153 ) -> Result<AssistantEventStream, OpenAIError> {
154 #[cfg(not(feature = "byot"))]
155 {
156 if request.stream.is_some() && !request.stream.unwrap() {
157 return Err(OpenAIError::InvalidArgument(
158 "When stream is false, use Runs::submit_tool_outputs".into(),
159 ));
160 }
161
162 request.stream = Some(true);
163 }
164
165 Ok(self
166 .client
167 .post_stream_mapped_raw_events(
168 &format!(
169 "/threads/{}/runs/{run_id}/submit_tool_outputs",
170 self.thread_id
171 ),
172 request,
173 TryFrom::try_from,
174 )
175 .await)
176 }
177
178 #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)]
180 pub async fn cancel(&self, run_id: &str) -> Result<RunObject, OpenAIError> {
181 self.client
182 .post(
183 &format!("/threads/{}/runs/{run_id}/cancel", self.thread_id),
184 (),
185 )
186 .await
187 }
188}