openai_orch/lib.rs
//! A Rust client for the OpenAI API with built-in concurrency management.
//!
//! # Overview
//! `openai-orch` is designed to provide a simple interface for sending requests
//! to OpenAI in bulk, while managing concurrency at a global level. It also
//! provides configurable policies to control how concurrency, timeouts, and
//! retries are handled.
//!
//! # Usage
//! To use this library, create an `Orchestrator` with the desired policies and
//! keys. To use the `Orchestrator` from another thread or task, simply clone
//! it. To send a request, call `add_request` on the `Orchestrator`, and then
//! call `get_response` on the `Orchestrator` with the request ID returned by
//! `add_request`. The `Orchestrator` will handle concurrency automatically.
//!
//! # Example
//! ```rust
//! use openai_orch::prelude::*;
//!
//! #[tokio::main]
//! async fn main() {
//!   let policies = Policies::default();
//!   let keys = Keys::from_env().unwrap();
//!   let orchestrator = Orchestrator::new(policies, keys);
//!
//!   let request = ChatSisoRequest::new(
//!     "You are a helpful assistant.".to_string(),
//!     "What are you?".to_string(),
//!     Default::default(),
//!   );
//!   let request_id = orchestrator.add_request(request).await;
//!
//!   let response = orchestrator
//!     .get_response::<ChatSisoResponse>(request_id)
//!     .await;
//!   println!("{}", response.unwrap());
//! }
//! ```
//!
//! If you'd like, you can implement `OrchRequest` on your own request type;
//! see the `OrchRequest` trait for more information and the sketch below.
//! Currently the only request type implemented is `ChatSisoRequest`; `SISO`
//! stands for "Single Input Single Output".
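//!
//! As a rough sketch, a custom request might look like the following. The
//! `EchoRequest` and `EchoResponse` types are hypothetical placeholders for
//! your own types; a real implementation would call the OpenAI API inside
//! `send`, respecting the configured policies.
//!
//! ```ignore
//! use async_trait::async_trait;
//! use openai_orch::{keys::Keys, policies::Policies, OrchRequest, ResponseType};
//!
//! // Hypothetical response type; `ResponseType` is a marker trait.
//! pub struct EchoResponse(pub String);
//! impl ResponseType for EchoResponse {}
//!
//! // Hypothetical request type that simply echoes its input.
//! pub struct EchoRequest {
//!   pub input: String,
//! }
//!
//! #[async_trait]
//! impl OrchRequest for EchoRequest {
//!   type Res = EchoResponse;
//!
//!   async fn send(
//!     &self,
//!     _policies: Policies,
//!     _keys: Keys,
//!     _id: u64,
//!   ) -> anyhow::Result<Self::Res> {
//!     // A real implementation would send an API request here.
//!     Ok(EchoResponse(self.input.clone()))
//!   }
//! }
//! ```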

pub mod chat;
pub mod embed;
pub mod keys;
pub mod policies;
pub mod prelude;
pub mod utils;

use std::{any::Any, collections::HashMap, marker::PhantomData, sync::Arc};

use anyhow::{Error, Result};
use async_trait::async_trait;
use tinyrand::Rand;
use tinyrand_std::thread_rand;
use tokio::sync::{mpsc, Mutex, Semaphore};

use crate::{keys::Keys, policies::Policies};

/// Marker trait for types that can be returned as responses by the
/// `Orchestrator`.
pub trait ResponseType: 'static + Send {}

/// Allows a request type to be used with the `Orchestrator`.
#[async_trait]
pub trait OrchRequest {
  /// The type of response returned by the request.
  type Res: ResponseType;
  /// The business logic of a request. Given the policies, keys, and request
  /// ID (the ID is provided for debugging), send the request and return the
  /// response.
  async fn send(
    &self,
    policies: Policies,
    keys: Keys,
    id: u64,
  ) -> Result<Self::Res>;
}

/// A unique identifier for a request. The type parameter ties the ID to the
/// response type it resolves to, so `get_response` can return the concrete
/// response type.
#[derive(Clone, Copy)]
pub struct RequestID<R: ResponseType> {
  id: u64,
  _marker: PhantomData<R>,
}

/// Receiving end of the channel a request task uses to send its result back
/// to the `Orchestrator`.
type ResponseReceiver = mpsc::Receiver<Result<Box<dyn Any + Send>>>;

/// The central interface for `openai_orch`. The `Orchestrator` is responsible
/// for managing the concurrency of requests and their responses.
///
/// Using the `Orchestrator` is simple:
/// 1. Create an `Orchestrator` with the desired policies and keys.
/// 2. Create a request type that implements `OrchRequest` (optional).
/// 3. Call `add_request` on the `Orchestrator` with the request.
/// 4. Call `get_response` on the `Orchestrator` with the request ID returned by
///    `add_request`.
///
/// The `Orchestrator` will handle the concurrency of requests and responses
/// automatically.
///
/// To use the `Orchestrator` in multiple parts of your application, you can
/// clone it. The `Orchestrator` is backed by an `Arc`, so cloning it is cheap.
///
/// ```rust
/// use openai_orch::{
///   chat::siso::{ChatSisoRequest, ChatSisoResponse},
///   keys::Keys,
///   policies::Policies,
///   Orchestrator,
/// };
///
/// #[tokio::main]
/// async fn main() {
///   let policies = Policies::default();
///   let keys = Keys::from_env().unwrap();
///   let orchestrator = Orchestrator::new(policies, keys);
///
///   let request = ChatSisoRequest::new(
///     "You are a helpful assistant.".to_string(),
///     "What are you?".to_string(),
///     Default::default(),
///   );
///   let request_id = orchestrator.add_request(request).await;
///
///   let response = orchestrator
///     .get_response::<ChatSisoResponse>(request_id)
///     .await;
///   println!("{}", response.unwrap());
/// }
/// ```
#[derive(Clone)]
pub struct Orchestrator {
  requests: Arc<Mutex<HashMap<u64, ResponseReceiver>>>,
  semaphore: Arc<Semaphore>,
  policies: Policies,
  keys: Keys,
}

impl Orchestrator {
  /// Create a new `Orchestrator` with the given policies and keys.
  pub fn new(policies: Policies, keys: Keys) -> Self {
    Self {
      requests: Arc::new(Mutex::new(HashMap::new())),
      semaphore: Arc::new(Semaphore::new(
        policies.concurrency_policy.max_concurrent_requests,
      )),
      policies,
      keys,
    }
  }

  /// Add a request to the `Orchestrator`. Returns a request ID that can be
  /// used to get the response.
  ///
  /// Behind the scenes, the `Orchestrator` spawns a task for the request that
  /// calls the `OrchRequest`'s `send` method once the concurrency policy
  /// allows it. The result is sent back to the `Orchestrator` over a channel
  /// keyed by the request ID.
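  ///
  /// A minimal usage sketch (assuming an `Orchestrator` named `orchestrator`
  /// and a `request` value whose type implements `OrchRequest`):
  ///
  /// ```ignore
  /// let request_id = orchestrator.add_request(request).await;
  /// ```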
  pub async fn add_request<R, Req>(&self, request: Req) -> RequestID<R>
  where
    Req: OrchRequest<Res = R> + Send + Sync + 'static,
    R: ResponseType,
  {
    let id = thread_rand().next_u64();
    let (tx, rx) = mpsc::channel(1);
    self.requests.lock().await.insert(id, rx);

    let semaphore = self.semaphore.clone();
    let policies = self.policies.clone();
    let keys = self.keys.clone();

    tokio::spawn(async move {
      // Hold a permit for the duration of the request so that at most
      // `max_concurrent_requests` requests are in flight at any one time.
      let _permit = semaphore
        .acquire()
        .await
        .expect("semaphore closed unexpectedly");

      // Erase the concrete response type so that responses from different
      // request types can share the same channel map.
      let res = request
        .send(policies, keys, id)
        .await
        .map(|res| Box::new(res) as Box<dyn Any + Send>);
      // If the receiver has been dropped, the response is simply discarded.
      let _ = tx.send(res).await;
    });

    RequestID {
      id,
      _marker: PhantomData,
    }
  }

  /// Get the response for a given request ID.
  ///
  /// This will wait until the response is received.
  ///
  /// Behind the scenes, this listens on a channel for the request's task to
  /// send the response back to the `Orchestrator`. Once the response is
  /// received, it is returned.
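  ///
  /// A minimal usage sketch (assuming `request_id` came from a matching
  /// `add_request` call for a `ChatSisoRequest`):
  ///
  /// ```ignore
  /// let response = orchestrator
  ///   .get_response::<ChatSisoResponse>(request_id)
  ///   .await;
  /// println!("{}", response.unwrap());
  /// ```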
  pub async fn get_response<R: ResponseType>(
    &self,
    request_id: RequestID<R>,
  ) -> Result<R> {
    // Take the receiver out of the map so the entry is cleaned up once the
    // response has been consumed.
    let mut rx = self
      .requests
      .lock()
      .await
      .remove(&request_id.id)
      .ok_or_else(|| Error::msg("No response receiver found"))?;

    // Downcast from `Box<dyn Any>` back to the concrete response type tied to
    // the request ID.
    rx.recv()
      .await
      .ok_or_else(|| Error::msg("No response found"))?
      .map(|res| *res.downcast::<R>().expect("Failed to downcast response"))
  }
}