bob_runtime/
tower_service.rs1use std::{
35 future::Future,
36 pin::Pin,
37 sync::Arc,
38 task::{Context, Poll},
39};
40
41use bob_core::{
42 error::{LlmError, ToolError},
43 ports::{LlmPort, ToolPort},
44 types::{LlmRequest, LlmResponse, ToolCall, ToolDescriptor, ToolResult},
45};
46
47#[derive(Debug, Clone)]
51pub struct ToolRequest {
52 pub call: ToolCall,
54}
55
56impl ToolRequest {
57 #[must_use]
59 pub fn new(call: ToolCall) -> Self {
60 Self { call }
61 }
62
63 #[must_use]
65 pub fn from_parts(name: impl Into<String>, arguments: serde_json::Value) -> Self {
66 Self { call: ToolCall::new(name, arguments) }
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct ToolResponse {
73 pub result: ToolResult,
75}
76
77#[derive(Debug, Clone)]
79pub struct LlmRequestWrapper {
80 pub request: LlmRequest,
82}
83
84#[derive(Debug, Clone)]
86pub struct LlmResponseWrapper {
87 pub response: LlmResponse,
89}
90
/// Marker request accepted by [`ToolListService`]; it carries no data.
///
/// Being a unit type, all values are identical, so the comparison and hashing
/// traits are derived to let the marker be used in assertions and as a map key.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ToolListRequest;
94
95pub struct ToolService {
102 inner: Arc<dyn ToolPort>,
103}
104
105impl ToolService {
106 #[must_use]
108 pub fn new(port: Arc<dyn ToolPort>) -> Self {
109 Self { inner: port }
110 }
111}
112
113impl std::fmt::Debug for ToolService {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("ToolService").finish_non_exhaustive()
116 }
117}
118
119impl tower::Service<ToolRequest> for ToolService {
120 type Response = ToolResponse;
121 type Error = ToolError;
122 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
123
124 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
125 Poll::Ready(Ok(()))
126 }
127
128 fn call(&mut self, req: ToolRequest) -> Self::Future {
129 let inner = self.inner.clone();
130 Box::pin(async move {
131 let result = inner.call_tool(req.call).await?;
132 Ok(ToolResponse { result })
133 })
134 }
135}
136
137pub struct LlmService {
143 inner: Arc<dyn LlmPort>,
144}
145
146impl LlmService {
147 #[must_use]
149 pub fn new(port: Arc<dyn LlmPort>) -> Self {
150 Self { inner: port }
151 }
152}
153
154impl std::fmt::Debug for LlmService {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 f.debug_struct("LlmService").finish_non_exhaustive()
157 }
158}
159
160impl tower::Service<LlmRequestWrapper> for LlmService {
161 type Response = LlmResponseWrapper;
162 type Error = LlmError;
163 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
164
165 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166 Poll::Ready(Ok(()))
167 }
168
169 fn call(&mut self, req: LlmRequestWrapper) -> Self::Future {
170 let inner = self.inner.clone();
171 Box::pin(async move {
172 let response = inner.complete(req.request).await?;
173 Ok(LlmResponseWrapper { response })
174 })
175 }
176}
177
178pub struct ToolListService {
182 inner: Arc<dyn ToolPort>,
183}
184
185impl ToolListService {
186 #[must_use]
188 pub fn new(port: Arc<dyn ToolPort>) -> Self {
189 Self { inner: port }
190 }
191}
192
193impl std::fmt::Debug for ToolListService {
194 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195 f.debug_struct("ToolListService").finish_non_exhaustive()
196 }
197}
198
199impl tower::Service<ToolListRequest> for ToolListService {
200 type Response = Vec<ToolDescriptor>;
201 type Error = ToolError;
202 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
203
204 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205 Poll::Ready(Ok(()))
206 }
207
208 fn call(&mut self, _req: ToolListRequest) -> Self::Future {
209 let inner = self.inner.clone();
210 Box::pin(async move { inner.list_tools().await })
211 }
212}
213
214pub trait ServiceExt<Request>: tower::Service<Request> + Sized {
221 fn with_timeout(self, timeout: std::time::Duration) -> tower::timeout::Timeout<Self> {
223 tower::timeout::Timeout::new(self, timeout)
224 }
225
226 fn with_rate_limit(
228 self,
229 max: u64,
230 interval: std::time::Duration,
231 ) -> tower::limit::RateLimit<Self> {
232 tower::limit::RateLimit::new(self, tower::limit::rate::Rate::new(max, interval))
233 }
234
235 fn with_concurrency_limit(self, max: usize) -> tower::limit::ConcurrencyLimit<Self> {
237 tower::limit::ConcurrencyLimit::new(self, max)
238 }
239
240 fn map_err<F, E2>(self, f: F) -> tower::util::MapErr<Self, F>
242 where
243 F: FnOnce(Self::Error) -> E2,
244 {
245 tower::util::MapErr::new(self, f)
246 }
247
248 fn map_response<F, Response2>(self, f: F) -> tower::util::MapResponse<Self, F>
250 where
251 F: FnOnce(Self::Response) -> Response2,
252 {
253 tower::util::MapResponse::new(self, f)
254 }
255
256 fn boxed(self) -> tower::util::BoxService<Request, Self::Response, Self::Error>
258 where
259 Self: Send + 'static,
260 Request: Send + 'static,
261 Self::Future: Send + 'static,
262 {
263 tower::util::BoxService::new(self)
264 }
265}
266
267impl<T, Request> ServiceExt<Request> for T where T: tower::Service<Request> + Sized {}
269
270pub trait ToolPortServiceExt: ToolPort {
276 fn into_tool_service(self: Arc<Self>) -> ToolService;
278
279 fn into_tool_list_service(self: Arc<Self>) -> ToolListService;
281}
282
283impl<T: ToolPort + 'static> ToolPortServiceExt for T {
285 fn into_tool_service(self: Arc<Self>) -> ToolService {
286 ToolService::new(self)
287 }
288
289 fn into_tool_list_service(self: Arc<Self>) -> ToolListService {
290 ToolListService::new(self)
291 }
292}
293
294impl ToolPortServiceExt for dyn ToolPort {
296 fn into_tool_service(self: Arc<Self>) -> ToolService {
297 ToolService::new(self)
298 }
299
300 fn into_tool_list_service(self: Arc<Self>) -> ToolListService {
301 ToolListService::new(self)
302 }
303}
304
305pub trait LlmPortServiceExt: LlmPort {
307 fn into_llm_service(self: Arc<Self>) -> LlmService;
309}
310
311impl<T: LlmPort + 'static> LlmPortServiceExt for T {
313 fn into_llm_service(self: Arc<Self>) -> LlmService {
314 LlmService::new(self)
315 }
316}
317
318impl LlmPortServiceExt for dyn LlmPort {
320 fn into_llm_service(self: Arc<Self>) -> LlmService {
321 LlmService::new(self)
322 }
323}
324
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use bob_core::types::ToolDescriptor;
    use tower::Service;

    use super::*;

    /// Test double for [`ToolPort`] that records the name of every tool it is asked to call.
    struct MockToolPort {
        calls: Mutex<Vec<String>>,
    }

    impl MockToolPort {
        fn new() -> Self {
            Self { calls: Mutex::new(Vec::new()) }
        }
    }

    #[async_trait::async_trait]
    impl ToolPort for MockToolPort {
        async fn list_tools(&self) -> Result<Vec<ToolDescriptor>, ToolError> {
            let echo = ToolDescriptor::new("mock/echo", "Echo tool");
            Ok(vec![echo])
        }

        async fn call_tool(&self, call: ToolCall) -> Result<ToolResult, ToolError> {
            // A poisoned lock is still usable for recording; the guard is released
            // at the end of this statement.
            self.calls
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push(call.name.clone());

            Ok(ToolResult {
                name: call.name,
                output: serde_json::json!({"ok": true}),
                is_error: false,
            })
        }
    }

    /// Fresh mock port behind the trait-object handle the services expect.
    fn mock_port() -> Arc<dyn ToolPort> {
        Arc::new(MockToolPort::new())
    }

    /// Request for the mock's echo tool with an empty argument object.
    fn echo_request() -> ToolRequest {
        ToolRequest::from_parts("mock/echo", serde_json::json!({}))
    }

    #[tokio::test]
    async fn tool_service_basic() {
        let mut svc = ToolService::new(mock_port());

        let resp = svc.call(echo_request()).await;
        assert!(resp.is_ok());
        assert_eq!(resp.unwrap().result.name, "mock/echo");
    }

    #[tokio::test]
    async fn tool_list_service() {
        let mut svc = ToolListService::new(mock_port());

        let tools = svc.call(ToolListRequest).await.expect("should list tools");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].id, "mock/echo");
    }

    #[tokio::test]
    async fn service_ext_timeout() {
        let one_second = std::time::Duration::from_secs(1);
        let mut timeout_svc = ToolService::new(mock_port()).with_timeout(one_second);

        let resp = timeout_svc.call(echo_request()).await;
        assert!(resp.is_ok());
    }

    #[tokio::test]
    async fn port_service_ext() {
        let port = mock_port();
        let mut svc = port.into_tool_service();

        let resp = svc.call(echo_request()).await;
        assert!(resp.is_ok());
    }
}