import torch
class Foo(torch.jit.ScriptModule):
    """Scripted module computing ``2 * x + y + value``.

    ``value`` is the tensor passed to the constructor; it is stored as a
    registered buffer named ``value``.
    """

    def __init__(self, v):
        """Register ``v`` as the ``value`` buffer read by ``forward``."""
        super().__init__()
        self.register_buffer('value', v)

    @torch.jit.script_method
    def forward(self, x, y):
        """Return ``2 * x + y`` plus the ``value`` buffer."""
        doubled = 2 * x
        return doubled + y + self.value
# Build the scripted module around a one-element buffer and serialize it.
foo = Foo(v=torch.Tensor([42.0]))
torch.jit.save(foo, 'foo.pt')
class Foo1(torch.nn.Module):
    """Eager module computing ``2 * x + y``, meant to be passed to ``torch.jit.trace``.

    This is deliberately a plain ``torch.nn.Module`` rather than a
    ``torch.jit.ScriptModule``: ``torch.jit.trace`` hands a ``ScriptModule``
    input back unchanged instead of recording its ``forward``, which would
    leave the exported file without a callable ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return ``2 * x + y``."""
        return 2 * x + y
# Trace Foo1 with two random length-3 example inputs, then serialize the result.
foo = Foo1()
traced_foo = torch.jit.trace(
    foo,
    example_inputs=(torch.rand(3), torch.rand(3)),
)
torch.jit.save(traced_foo, 'foo1.pt')
class Foo2(torch.nn.Module):
    """Eager module returning the pair ``(2 * x + y, x - y)``, meant for ``torch.jit.trace``.

    This is deliberately a plain ``torch.nn.Module`` rather than a
    ``torch.jit.ScriptModule``: ``torch.jit.trace`` hands a ``ScriptModule``
    input back unchanged instead of recording its ``forward``, which would
    leave the exported file without a callable ``forward``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return a 2-tuple: ``2 * x + y`` and ``x - y``."""
        return (2 * x + y, x - y)
# Trace Foo2 with two random length-3 example inputs, then serialize the result.
foo = Foo2()
traced_foo = torch.jit.trace(
    foo,
    example_inputs=(torch.rand(3), torch.rand(3)),
)
torch.jit.save(traced_foo, 'foo2.pt')
class Foo3(torch.jit.ScriptModule):
    """Scripted module multiplying together the entries of ``x`` along dim 0."""

    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def forward(self, x):
        """Return ``x[0] * x[1] * ... * x[n - 1]`` where ``n = x.size(0)``."""
        product = x[0]
        # Index 0 already seeds the product, so accumulation starts at 1.
        for idx in range(1, x.size(0)):
            product = product * x[idx]
        return product
# Instantiate the scripted product module and write it to disk.
foo = Foo3()
torch.jit.save(foo, 'foo3.pt')
from typing import Tuple, List
class Foo4(torch.jit.ScriptModule):
    """Scripted module evaluating ``x[0] + x[1] * x[2]`` on a ``(float, float, int)`` tuple."""

    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def forward(self, x: Tuple[float, float, int]):
        """Return the first element plus the product of the second and third."""
        base, factor, count = x
        return base + factor * count
# Instantiate the scripted tuple-arithmetic module and write it to disk.
foo = Foo4()
torch.jit.save(foo, 'foo4.pt')
class Foo5(torch.jit.ScriptModule):
    """Scripted module that drops the last character of every string in a list."""

    def __init__(self):
        super().__init__()

    @torch.jit.script_method
    def forward(self, xs: List[str]):
        """Return a new list holding ``s[:-1]`` for each ``s`` in ``xs``."""
        trimmed: List[str] = []
        for s in xs:
            trimmed.append(s[:-1])
        return trimmed
# Instantiate the scripted string-trimming module and write it to disk.
foo = Foo5()
torch.jit.save(foo, 'foo5.pt')