// openai_orch/lib.rs

1//! A concurrency-included Rust client for the OpenAI API.
2//!
3//! # Overview
4//! `openai-orch` is designed to provide a simple interface for sending requests
5//! to OpenAI in bulk, while managing concurrency at a global level. It also
6//! provides configurable policies to control how concurrency, timeouts, and
7//! retries are handled.
8//!
9//! # Usage
10//! To use this library, create an `Orchestrator` with the desired policies and
11//! keys. To allow a thread to use the `Orchestrator`, simply clone it. To send
//! a request, call `add_request` on the `Orchestrator`, and then call `get_response`
//! on the `Orchestrator` with the request ID returned by `add_request`. The
14//! `Orchestrator` will handle concurrency automatically.
15//!
16//! # Example
17//! ```rust
18//! use openai_orch::prelude::*;
19//! 
20//! #[tokio::main]
21//! async fn main() {
22//!   let policies = Policies::default();
23//!   let keys = Keys::from_env().unwrap();
24//!   let orchestrator = Orchestrator::new(policies, keys);
25//! 
26//!   let request = ChatSisoRequest::new(
27//!     "You are a helpful assistant.".to_string(),
28//!     "What are you?".to_string(),
29//!     Default::default(),
30//!   );
31//!   let request_id = orchestrator.add_request(request).await;
32//! 
33//!   let response = orchestrator
34//!     .get_response::<ChatSisoResponse>(request_id)
35//!     .await;
36//!   println!("{}", response.unwrap());
37//! }
38//! ```
39//!
40//! If you'd like, you can implement `OrchRequest` on your own request type.
41//! See the `OrchRequest` trait for more information. Currently the only request
42//! type implemented is `ChatSisoRequest`; `SISO` stands for "Single Input Single
43//! Output".
44
45pub mod chat;
46pub mod embed;
47pub mod keys;
48pub mod policies;
49pub mod prelude;
50pub mod utils;
51
52use std::{any::Any, collections::HashMap, marker::PhantomData, sync::Arc};
53
54use anyhow::{Error, Result};
55use async_trait::async_trait;
56use tinyrand::Rand;
57use tinyrand_std::thread_rand;
58use tokio::sync::{mpsc, Mutex, Semaphore};
59
60use crate::{keys::Keys, policies::Policies};
61
/// Marker trait for types that can be produced by an [`OrchRequest`].
///
/// The `'static + Send` bounds exist because responses are type-erased into
/// `Box<dyn Any + Send>`, sent across a channel from a spawned task, and
/// later downcast back to the concrete type in `get_response`.
pub trait ResponseType: 'static + Send {}
63
/// Allows a request type to be used with the `Orchestrator`.
#[async_trait]
pub trait OrchRequest {
  /// The type of response returned by the request.
  type Res: ResponseType;
  /// Business logic of a request. Given the policies, keys, and request ID
  /// (for debugging), send the request and return the response.
  ///
  /// Called once per request from a task spawned by `add_request`, after a
  /// concurrency permit has been acquired.
  async fn send(
    &self,
    policies: Policies,
    keys: Keys,
    id: u64,
  ) -> Result<Self::Res>;
}
78
79/// A unique identifier for a request.
80#[derive(Clone, Copy)]
81pub struct RequestID<R: ResponseType> {
82  id:      u64,
83  _marker: PhantomData<R>,
84}
85
/// Receiving half of the channel on which a spawned request task delivers
/// its type-erased (`Box<dyn Any + Send>`) result; stored in the
/// orchestrator's request map keyed by request ID.
type ResponseReceiver = mpsc::Receiver<Result<Box<dyn Any + Send>>>;
87
88/// The central interface for `openai_orch`. The `Orchestrator` is responsible
89/// for managing the concurrency of requests and their responses.
90///
91/// Using the `Orchestrator` is simple:
92/// 1. Create an `Orchestrator` with the desired policies and keys.
93/// 2. Create a request type that implements `OrchRequest` (optional).
94/// 3. Call `add_request` on the `Orchestrator` with the request handler.
95/// 4. Call `get_response` on the `Orchestrator` with the request ID returned by
96///    `add_request`.
97///
98/// The `Orchestrator` will handle the concurrency of requests and responses
99/// automatically.
100///
101/// To use the `Orchestrator` in multiple parts of your application, you can
102/// clone it. The `Orchestrator` is backed by an `Arc`, so cloning it is cheap.
103///
104/// ```rust
105/// use openai_orch::{
106///   chat::siso::{ChatSisoRequest, ChatSisoResponse},
107///   keys::Keys,
108///   policies::Policies,
109///   Orchestrator,
110/// };
111/// 
112/// #[tokio::main]
113/// async fn main() {
114///   let policies = Policies::default();
115///   let keys = Keys::from_env().unwrap();
116///   let orchestrator = Orchestrator::new(policies, keys);
117/// 
118///   let request = ChatSisoRequest::new(
119///     "You are a helpful assistant.".to_string(),
120///     "What are you?".to_string(),
121///     Default::default(),
122///   );
123///   let request_id = orchestrator.add_request(request).await;
124/// 
125///   let response = orchestrator
126///     .get_response::<ChatSisoResponse>(request_id)
127///     .await;
128///   println!("{}", response.unwrap());
129/// }
130/// ```
#[derive(Clone)]
pub struct Orchestrator {
  // Maps request IDs to the channel on which each spawned task delivers its
  // type-erased result. Behind an `Arc`, so all clones share one map.
  requests:  Arc<Mutex<HashMap<u64, ResponseReceiver>>>,
  // Global concurrency limit; sized in `new` from the concurrency policy and
  // shared across clones.
  semaphore: Arc<Semaphore>,
  policies:  Policies,
  keys:      Keys,
}
138
139impl Orchestrator {
140  /// Create a new `Orchestrator` with the given policies and keys.
141  pub fn new(policies: Policies, keys: Keys) -> Self {
142    Self {
143      requests: Arc::new(Mutex::new(HashMap::new())),
144      semaphore: Arc::new(Semaphore::new(
145        policies.concurrency_policy.max_concurrent_requests,
146      )),
147      policies,
148      keys,
149    }
150  }
151
152  /// Add a request to the `Orchestrator`. Returns a request ID that can be used
153  /// to get the response.
154  ///
155  /// Behind the scenes the `Orchestrator` will create a task for the request
156  /// using the `OrchRequest`'s `send` method when the concurrency policy
157  /// allows it. The result will be sent back to the `Orchestrator` using a
158  /// channel which is mapped to the request ID.
159  pub async fn add_request<R, Req>(&self, request: Req) -> RequestID<R>
160  where
161    Req: OrchRequest<Res = R> + Send + Sync + 'static,
162    R: ResponseType,
163  {
164    let id = thread_rand().next_u64();
165    let (tx, rx) = mpsc::channel(1);
166    self.requests.lock().await.insert(id, rx);
167
168    let semaphore = self.semaphore.clone();
169    let policies = self.policies.clone();
170    let keys = self.keys.clone();
171
172    tokio::spawn(async move {
173      let _permit = semaphore
174        .acquire()
175        .await
176        .expect("failed to acquire semaphore; this is UB");
177
178      let res = request
179        .send(policies, keys, id)
180        .await
181        .map(|res| Box::new(res) as Box<dyn Any + Send>);
182      let _ = tx.send(res).await;
183    });
184
185    RequestID {
186      id,
187      _marker: PhantomData,
188    }
189  }
190
191  /// Get the response for a given request ID.
192  ///
193  /// This will block until the response is received.
194  ///
195  /// Behind the scenes, this listens on a channel for a task to send the
196  /// response back to the `Orchestrator`. Once the response is received, it is
197  /// returned.
198  pub async fn get_response<R: ResponseType>(
199    &self,
200    request_id: RequestID<R>,
201  ) -> Result<R> {
202    let mut rx = self
203      .requests
204      .lock()
205      .await
206      .remove(&request_id.id)
207      .ok_or_else(|| Error::msg("No response receiver found"))?;
208
209    rx.recv()
210      .await
211      .ok_or_else(|| Error::msg("No response found"))?
212      .map(|res| *res.downcast::<R>().expect("Failed to downcast response"))
213  }
214}