Skip to main content

gestalt/
agent.rs

1use std::sync::Arc;
2
3use hyper_util::rt::TokioIo;
4use tokio::net::UnixStream;
5use tonic::codegen::async_trait;
6use tonic::metadata::MetadataValue;
7use tonic::service::Interceptor;
8use tonic::service::interceptor::InterceptedService;
9use tonic::transport::{Channel, ClientTlsConfig, Endpoint, Uri};
10use tonic::{Request as GrpcRequest, Response as GrpcResponse, Status};
11use tower::service_fn;
12
13use crate::api::RuntimeMetadata;
14use crate::error::Result as ProviderResult;
15use crate::generated::v1::{
16    self as pb, agent_host_client::AgentHostClient as ProtoAgentHostClient,
17};
18
19type AgentHostTransport = InterceptedService<Channel, AgentHostRelayTokenInterceptor>;
20
/// Name of the environment variable that holds the agent-host transport target: either a bare
/// Unix socket path or a `unix://`, `tcp://`, or `tls://` target.
pub const ENV_AGENT_HOST_SOCKET: &'static str = "GESTALT_AGENT_HOST_SOCKET";

/// Name of the environment variable that holds the optional relay token sent to the agent
/// host with each RPC.
pub const ENV_AGENT_HOST_SOCKET_TOKEN: &'static str = "GESTALT_AGENT_HOST_SOCKET_TOKEN";

/// gRPC metadata key under which the relay token is attached to outgoing requests.
const AGENT_HOST_RELAY_TOKEN_HEADER: &'static str = "x-gestalt-host-service-relay-token";
26
27#[derive(Debug, thiserror::Error)]
28/// Errors returned by [`AgentHost`].
29pub enum AgentHostError {
30    /// The host-service transport could not be created.
31    #[error("{0}")]
32    Transport(#[from] tonic::transport::Error),
33    /// The host-service RPC returned a gRPC status.
34    #[error("{0}")]
35    Status(#[from] tonic::Status),
36    /// Required environment or target configuration was invalid.
37    #[error("{0}")]
38    Env(String),
39}
40
41/// Client for the agent host service available inside agent providers.
42pub struct AgentHost {
43    client: ProtoAgentHostClient<AgentHostTransport>,
44}
45
46impl AgentHost {
47    /// Connects to the agent host service described by the environment.
48    pub async fn connect() -> std::result::Result<Self, AgentHostError> {
49        let target = std::env::var(ENV_AGENT_HOST_SOCKET)
50            .map_err(|_| AgentHostError::Env(format!("{ENV_AGENT_HOST_SOCKET} is not set")))?;
51        let relay_token = std::env::var(ENV_AGENT_HOST_SOCKET_TOKEN).unwrap_or_default();
52        let channel = match parse_agent_host_target(&target)? {
53            AgentHostTarget::Unix(path) => connect_unix(path).await?,
54            AgentHostTarget::Tcp(address) => {
55                Endpoint::from_shared(format!("http://{address}"))?
56                    .connect()
57                    .await?
58            }
59            AgentHostTarget::Tls(address) => {
60                Endpoint::from_shared(format!("https://{address}"))?
61                    .tls_config(ClientTlsConfig::new().with_native_roots())?
62                    .connect()
63                    .await?
64            }
65        };
66        Ok(Self {
67            client: ProtoAgentHostClient::with_interceptor(
68                channel,
69                agent_host_relay_token_interceptor(relay_token.trim())?,
70            ),
71        })
72    }
73
74    /// Executes a host tool using an agent protocol request message.
75    pub async fn execute_tool(
76        &mut self,
77        request: pb::ExecuteAgentToolRequest,
78    ) -> std::result::Result<pb::ExecuteAgentToolResponse, AgentHostError> {
79        Ok(self.client.execute_tool(request).await?.into_inner())
80    }
81
82    /// Lists host tools visible to the current agent request.
83    pub async fn list_tools(
84        &mut self,
85        request: pb::ListAgentToolsRequest,
86    ) -> std::result::Result<pb::ListAgentToolsResponse, AgentHostError> {
87        Ok(self.client.list_tools(request).await?.into_inner())
88    }
89
90    /// Resolves a configured agent connection for the current turn.
91    pub async fn resolve_connection(
92        &mut self,
93        request: pb::ResolveAgentConnectionRequest,
94    ) -> std::result::Result<pb::ResolvedAgentConnection, AgentHostError> {
95        Ok(self.client.resolve_connection(request).await?.into_inner())
96    }
97}
98
99async fn connect_unix(
100    socket_path: String,
101) -> std::result::Result<Channel, tonic::transport::Error> {
102    Endpoint::try_from("http://[::]:50051")?
103        .connect_with_connector(service_fn(move |_: Uri| {
104            let path = socket_path.clone();
105            async move { UnixStream::connect(path).await.map(TokioIo::new) }
106        }))
107        .await
108}
109
110#[derive(Clone)]
111struct AgentHostRelayTokenInterceptor {
112    token: Option<MetadataValue<tonic::metadata::Ascii>>,
113}
114
115impl Interceptor for AgentHostRelayTokenInterceptor {
116    fn call(
117        &mut self,
118        mut request: tonic::Request<()>,
119    ) -> std::result::Result<tonic::Request<()>, tonic::Status> {
120        if let Some(token) = self.token.clone() {
121            request
122                .metadata_mut()
123                .insert(AGENT_HOST_RELAY_TOKEN_HEADER, token);
124        }
125        Ok(request)
126    }
127}
128
129fn agent_host_relay_token_interceptor(
130    token: &str,
131) -> std::result::Result<AgentHostRelayTokenInterceptor, AgentHostError> {
132    let trimmed = token.trim();
133    let token = if trimmed.is_empty() {
134        None
135    } else {
136        Some(MetadataValue::try_from(trimmed).map_err(|err| {
137            AgentHostError::Env(format!("agent host: invalid relay token metadata: {err}"))
138        })?)
139    };
140    Ok(AgentHostRelayTokenInterceptor { token })
141}
142
/// Parsed form of the agent-host transport target.
enum AgentHostTarget {
    /// Filesystem path of the Unix domain socket to dial.
    Unix(String),
    /// `host:port` reached over a plaintext `http://` endpoint.
    Tcp(String),
    /// `host:port` reached over an `https://` endpoint with TLS.
    Tls(String),
}
148
149fn parse_agent_host_target(raw: &str) -> std::result::Result<AgentHostTarget, AgentHostError> {
150    let target = raw.trim();
151    if target.is_empty() {
152        return Err(AgentHostError::Env(
153            "agent host: transport target is required".to_string(),
154        ));
155    }
156    if let Some(address) = target.strip_prefix("tcp://") {
157        let address = address.trim();
158        if address.is_empty() {
159            return Err(AgentHostError::Env(format!(
160                "agent host: tcp target {raw:?} is missing host:port"
161            )));
162        }
163        return Ok(AgentHostTarget::Tcp(address.to_string()));
164    }
165    if let Some(address) = target.strip_prefix("tls://") {
166        let address = address.trim();
167        if address.is_empty() {
168            return Err(AgentHostError::Env(format!(
169                "agent host: tls target {raw:?} is missing host:port"
170            )));
171        }
172        return Ok(AgentHostTarget::Tls(address.to_string()));
173    }
174    if let Some(path) = target.strip_prefix("unix://") {
175        let path = path.trim();
176        if path.is_empty() {
177            return Err(AgentHostError::Env(format!(
178                "agent host: unix target {raw:?} is missing a socket path"
179            )));
180        }
181        return Ok(AgentHostTarget::Unix(path.to_string()));
182    }
183    if target.contains("://") {
184        return Err(AgentHostError::Env(format!(
185            "agent host: unsupported target scheme in {raw:?}"
186        )));
187    }
188    Ok(AgentHostTarget::Unix(target.to_string()))
189}
190
191#[async_trait]
192/// Provider trait for serving the Gestalt agent-provider protocol.
193pub trait AgentProvider: pb::agent_provider_server::AgentProvider + Send + Sync + 'static {
194    /// Configures the provider before it starts serving requests.
195    async fn configure(
196        &self,
197        _name: &str,
198        _config: serde_json::Map<String, serde_json::Value>,
199    ) -> ProviderResult<()> {
200        Ok(())
201    }
202
203    /// Returns runtime metadata that should augment the static manifest.
204    fn metadata(&self) -> Option<RuntimeMetadata> {
205        None
206    }
207
208    /// Returns non-fatal warnings the host should surface to users.
209    fn warnings(&self) -> Vec<String> {
210        Vec::new()
211    }
212
213    /// Performs an optional health check.
214    async fn health_check(&self) -> ProviderResult<()> {
215        Ok(())
216    }
217
218    /// Starts provider-owned background work after configuration.
219    async fn start(&self) -> ProviderResult<()> {
220        Ok(())
221    }
222
223    /// Shuts the provider down before the runtime exits.
224    async fn close(&self) -> ProviderResult<()> {
225        Ok(())
226    }
227}
228
/// Adapter that exposes a user [`AgentProvider`] as the generated gRPC service.
pub(crate) struct AgentServer<P> {
    provider: Arc<P>,
}

// Hand-written rather than derived: `#[derive(Clone)]` would add a `P: Clone` bound. Only the
// `Arc<P>` handle is stored, and it is always cloneable; clones share the same provider.
impl<P> Clone for AgentServer<P> {
    fn clone(&self) -> Self {
        Self {
            provider: Arc::clone(&self.provider),
        }
    }
}
233
234impl<P> AgentServer<P> {
235    pub(crate) fn new(provider: Arc<P>) -> Self {
236        Self { provider }
237    }
238}
239
240#[async_trait]
241impl<P> pb::agent_provider_server::AgentProvider for AgentServer<P>
242where
243    P: AgentProvider,
244{
245    async fn create_session(
246        &self,
247        request: GrpcRequest<pb::CreateAgentProviderSessionRequest>,
248    ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
249        self.provider.create_session(request).await
250    }
251
252    async fn get_session(
253        &self,
254        request: GrpcRequest<pb::GetAgentProviderSessionRequest>,
255    ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
256        self.provider.get_session(request).await
257    }
258
259    async fn list_sessions(
260        &self,
261        request: GrpcRequest<pb::ListAgentProviderSessionsRequest>,
262    ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderSessionsResponse>, Status> {
263        self.provider.list_sessions(request).await
264    }
265
266    async fn update_session(
267        &self,
268        request: GrpcRequest<pb::UpdateAgentProviderSessionRequest>,
269    ) -> std::result::Result<GrpcResponse<pb::AgentSession>, Status> {
270        self.provider.update_session(request).await
271    }
272
273    async fn create_turn(
274        &self,
275        request: GrpcRequest<pb::CreateAgentProviderTurnRequest>,
276    ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
277        self.provider.create_turn(request).await
278    }
279
280    async fn get_turn(
281        &self,
282        request: GrpcRequest<pb::GetAgentProviderTurnRequest>,
283    ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
284        self.provider.get_turn(request).await
285    }
286
287    async fn list_turns(
288        &self,
289        request: GrpcRequest<pb::ListAgentProviderTurnsRequest>,
290    ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderTurnsResponse>, Status> {
291        self.provider.list_turns(request).await
292    }
293
294    async fn cancel_turn(
295        &self,
296        request: GrpcRequest<pb::CancelAgentProviderTurnRequest>,
297    ) -> std::result::Result<GrpcResponse<pb::AgentTurn>, Status> {
298        self.provider.cancel_turn(request).await
299    }
300
301    async fn list_turn_events(
302        &self,
303        request: GrpcRequest<pb::ListAgentProviderTurnEventsRequest>,
304    ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderTurnEventsResponse>, Status> {
305        self.provider.list_turn_events(request).await
306    }
307
308    async fn get_interaction(
309        &self,
310        request: GrpcRequest<pb::GetAgentProviderInteractionRequest>,
311    ) -> std::result::Result<GrpcResponse<pb::AgentInteraction>, Status> {
312        self.provider.get_interaction(request).await
313    }
314
315    async fn list_interactions(
316        &self,
317        request: GrpcRequest<pb::ListAgentProviderInteractionsRequest>,
318    ) -> std::result::Result<GrpcResponse<pb::ListAgentProviderInteractionsResponse>, Status> {
319        self.provider.list_interactions(request).await
320    }
321
322    async fn resolve_interaction(
323        &self,
324        request: GrpcRequest<pb::ResolveAgentProviderInteractionRequest>,
325    ) -> std::result::Result<GrpcResponse<pb::AgentInteraction>, Status> {
326        self.provider.resolve_interaction(request).await
327    }
328
329    async fn get_capabilities(
330        &self,
331        request: GrpcRequest<pb::GetAgentProviderCapabilitiesRequest>,
332    ) -> std::result::Result<GrpcResponse<pb::AgentProviderCapabilities>, Status> {
333        self.provider.get_capabilities(request).await
334    }
335}